From 51086d2b8d36ef6aa513cb017ed6361934950066 Mon Sep 17 00:00:00 2001 From: Chenjie Duan <44265800+kilinchange@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:02:13 +0800 Subject: [PATCH] Modify kernel registration & support fp16 (#205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * - Remove dataType from the kernel registration. * - support fp16 for conv * - cpu kernel: adapt the new registration mechanism * modified all register kernel * add where fp16 * add layernorm fp16 * add split_concat fp16 * - element_wise support fp16 * feat: support transpose fp16 * feat: support sliceOp fp16 * - unary support fp16 * - feat: support reduceOp fp16 * feat: support matmulOp/expandOp fp16 * feat: support powOp int8 * add cuda cast & support half-precision for gather * style: fix style * feat: support int8 for gather * style: fix style * modified test_cuda_conv_transposed * fix: fix dist code to support fp16 * fix(graph.cc): fix topo_sort * fix: fix recv and send kernel registration * feat: add field tensors for stub * refactor(frontend): sort nodes first, then build the graph Signed-off-by: YdrMaster * fix: provide tensor-to-node mapping for intermediate results * fix (slice): add guard for area out of range * fix: fix matmul fp16 * fix: fix re-dataMalloc for weight tensor and use of naive allocator * feat: add dataType filter for cuda kernel * feat: bang kernel adapt the new registration mechanism * fix: fix some error on mlu * feat: intelcpu kernel adapt the new registration mechanism * feat: modify kernel registration on kunlun * fix intelcpu compiler bug * feat: bang reshape support all dataType * fix: fix bang reduce * fix(all_reduce.cc): fix as reviewer suggested * fix: fix style and restore unary test codes --------- Signed-off-by: YdrMaster Co-authored-by: xgqdut2016 Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com> Co-authored-by: zhangyunze Co-authored-by: OdinaryWord Co-authored-by: YdrMaster Co-authored-by: panzezhong --- examples/distributed/parallel_opt.py | 11 +- include/core/kernel.h | 12 +- include/core/operator.h | 3 +- include/core/tensor.h | 12 +- include/cuda/cuda_element_wise.h | 21 +- include/cuda/cuda_expand.h | 5 +- include/cuda/cuda_layernorm.h | 6 + include/cuda/cuda_softmax.h | 4 +- include/cuda/cuda_split_concat.h | 17 +- include/cuda/cuda_transpose.h | 2 +- include/cuda/cuda_unary.h | 55 +- include/cuda/cuda_utility.h | 26 +- include/cuda/cuda_where.h | 7 +- include/cuda/gather.h | 3 +- include/utils/data_generator.h | 6 + pyinfinitensor/src/pyinfinitensor/onnx.py | 1642 ++++++++--------- src/bang/bang_runtime.cc | 3 +- src/core/graph.cc | 52 +- src/core/runtime.cc | 6 +- src/cuda/cuda_runtime.cc | 6 +- src/cuda/cuda_utility.cu | 53 + src/kernels/bang/activation.cc | 21 +- src/kernels/bang/activation_backward.cc | 13 +- src/kernels/bang/batchnorm.cc | 5 +- src/kernels/bang/cast.cc | 3 +- src/kernels/bang/ceil.cc | 4 +- src/kernels/bang/clip.cc | 4 +- src/kernels/bang/concat.cc | 4 +- src/kernels/bang/conv.cc | 4 +- src/kernels/bang/conv_trans.cc | 5 +- src/kernels/bang/convbpfilter.cc | 5 +- src/kernels/bang/det.cc | 4 +- src/kernels/bang/element_wise.cc | 105 +- src/kernels/bang/erf.cc | 4 +- src/kernels/bang/exp.cc | 4 +- src/kernels/bang/fill.cc | 4 +- src/kernels/bang/floor.cc | 3 +- src/kernels/bang/gather.cc | 4 +- src/kernels/bang/hardtanh.cc | 5 +- src/kernels/bang/l2loss.cc | 4 +- src/kernels/bang/layer_norm.cc | 5 +- src/kernels/bang/log.cc | 4 +- src/kernels/bang/lrn.cc | 4 +- src/kernels/bang/matmul.cc | 4 +- src/kernels/bang/negtensor.cc | 4 +-
src/kernels/bang/pad.cc | 4 +- src/kernels/bang/pooling.cc | 9 +- src/kernels/bang/reciprocal.cc | 5 +- src/kernels/bang/reduce.cc | 9 +- src/kernels/bang/reshape.cc | 17 +- src/kernels/bang/rsqrt.cc | 4 +- src/kernels/bang/split.cc | 4 +- src/kernels/bang/sqrt.cc | 4 +- src/kernels/bang/transpose.cc | 10 +- src/kernels/bang/trigon.cc | 37 +- src/kernels/bang/where.cc | 4 +- src/kernels/cpu/concat.cc | 28 +- src/kernels/cpu/conv.cc | 28 +- src/kernels/cpu/element_wise.cc | 173 +- src/kernels/cpu/matmul.cc | 28 +- src/kernels/cpu/membound.cc | 4 +- src/kernels/cpu/pooling.cc | 117 +- src/kernels/cpu/split.cc | 27 +- src/kernels/cpu/transpose.cc | 29 +- src/kernels/cpu/unary.cc | 396 ++-- src/kernels/cuda/G2BMM.cc | 4 +- src/kernels/cuda/GBMM.cc | 4 +- src/kernels/cuda/all_gather.cc | 4 +- src/kernels/cuda/all_reduce.cc | 35 +- src/kernels/cuda/attention_kvcache.cc | 5 +- src/kernels/cuda/batch_norm.cc | 5 +- src/kernels/cuda/broadcast.cc | 4 +- src/kernels/cuda/clip.cc | 5 +- src/kernels/cuda/conv.cc | 93 +- src/kernels/cuda/conv_half.cc | 261 --- src/kernels/cuda/conv_transposed.cc | 10 +- src/kernels/cuda/element_wise.cc | 71 +- src/kernels/cuda/element_wise.cu | 102 +- src/kernels/cuda/expand.cc | 6 +- src/kernels/cuda/expand.cu | 61 +- src/kernels/cuda/extend.cc | 4 +- src/kernels/cuda/gather.cc | 21 +- src/kernels/cuda/gather.cu | 21 +- src/kernels/cuda/gather_elements.cc | 7 +- src/kernels/cuda/layer_norm.cc | 43 +- src/kernels/cuda/layer_norm.cu | 297 ++- src/kernels/cuda/matmul.cc | 62 +- .../cuda/membound_tvm_extract_source.cc | 5 +- .../cuda/membound_tvm_packed_function.cc | 4 +- src/kernels/cuda/pad_slice.cc | 10 +- src/kernels/cuda/pad_slice.cu | 81 +- src/kernels/cuda/pooling.cc | 10 +- src/kernels/cuda/recv.cc | 3 +- src/kernels/cuda/reduce.cc | 29 +- src/kernels/cuda/reshape.cc | 21 +- src/kernels/cuda/resize.cc | 4 +- src/kernels/cuda/send.cc | 3 +- src/kernels/cuda/softmax.cc | 14 +- src/kernels/cuda/softmax.cu | 270 ++- src/kernels/cuda/split_concat.cc | 54 +- src/kernels/cuda/split_concat.cu | 34 +- src/kernels/cuda/transpose.cc | 20 +- src/kernels/cuda/transpose.cu | 60 +- src/kernels/cuda/unary.cc | 111 +- src/kernels/cuda/unary.cu | 209 ++- src/kernels/cuda/where.cc | 20 +- src/kernels/cuda/where.cu | 40 +- src/kernels/intelcpu/batch_norm.cc | 5 +- src/kernels/intelcpu/concat.cc | 4 +- src/kernels/intelcpu/conv.cc | 4 +- src/kernels/intelcpu/conv_transposed.cc | 5 +- src/kernels/intelcpu/element_wise.cc | 26 +- src/kernels/intelcpu/extend.cc | 4 +- src/kernels/intelcpu/gather.cc | 4 +- src/kernels/intelcpu/matmul.cc | 3 +- src/kernels/intelcpu/matmul_dpcpp.cc | 5 +- src/kernels/intelcpu/pad.cc | 4 +- src/kernels/intelcpu/pooling.cc | 8 +- src/kernels/intelcpu/pow.cc | 4 +- src/kernels/intelcpu/reduce.cc | 7 +- src/kernels/intelcpu/reshape.cc | 11 +- src/kernels/intelcpu/resize.cc | 4 +- src/kernels/intelcpu/slice.cc | 4 +- src/kernels/intelcpu/softmax.cc | 4 +- src/kernels/intelcpu/split.cc | 4 +- src/kernels/kunlun/batch_norm.cc | 5 +- src/kernels/kunlun/cast.cc | 3 +- src/kernels/kunlun/concat.cc | 5 +- src/kernels/kunlun/conv.cc | 4 +- src/kernels/kunlun/conv_trans.cc | 9 +- src/kernels/kunlun/element_wise.cc | 78 +- src/kernels/kunlun/gather.cc | 5 +- src/kernels/kunlun/matmul.cc | 5 +- src/kernels/kunlun/pad.cc | 4 +- src/kernels/kunlun/pooling.cc | 9 +- src/kernels/kunlun/reduce_mean.cc | 5 +- src/kernels/kunlun/select.cc | 4 +- src/kernels/kunlun/softmax.cc | 5 +- src/kernels/kunlun/split.cc | 4 +- src/kernels/kunlun/transpose.cc | 10 +- 
src/kernels/kunlun/unary.cc | 116 +- src/kunlun/kunlun_runtime.cc | 3 +- src/operators/layer_norm.cc | 5 +- src/utils/operator_utils.cc | 3 +- test/kernels/cuda/test_cuda_concat.cc | 38 + .../cuda/test_cuda_conv_transposed_2d.cc | 4 +- test/kernels/cuda/test_cuda_layernorm.cc | 90 +- test/kernels/cuda/test_cuda_softmax.cc | 227 ++- test/kernels/cuda/test_cuda_split.cc | 30 + test/kernels/cuda/test_cuda_unary.cc | 30 + test/kernels/cuda/test_cuda_where.cc | 96 +- test/kernels/intelcpu/test_mkl_conv.cc | 4 +- .../intelcpu/test_mkl_conv_transposed.cc | 4 +- test/kernels/intelcpu/test_mkl_pooling.cc | 2 +- test/kernels/intelcpu/test_mkl_reduce.cc | 2 +- test/operators/test_unary.cc | 3 +- test/operators/test_where.cc | 38 +- 157 files changed, 3627 insertions(+), 2575 deletions(-) delete mode 100644 src/kernels/cuda/conv_half.cc diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index 1214b6b3..bbb0ac65 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -137,7 +137,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): place[node.output[0]] = Shard(list(perm).index(plc.dim)) def shard_node(node: NodeProto): - if node.op_type in ["Relu", "Tanh", "Softmax"]: + if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]: place[node.output[0]] = place[node.input[0]] elif node.op_type in ["Where"]: place[node.output[0]] = place[node.input[1]] @@ -177,7 +177,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): input in data for input in node.input ): # FIXME(constroy): the last MatMul should not be sharded as TP. - if node.output[0] in output: + if ( + node.output[0] in output + or ( + index + 1 < len(model.graph.node) + and model.graph.node[index + 1].output[0] + ) + in output + ): continue groups = 1 # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups diff --git a/include/core/kernel.h b/include/core/kernel.h index a19f3f1a..76189599 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -30,7 +30,6 @@ class Kernel { public: Kernel() {} virtual ~Kernel() {} - /** * @param op The operator to be executed. * @param record The parameters for kernel execution. 
If extra parameters @@ -130,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel { } // namespace infini -#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \ +#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \ namespace infini { \ static const bool _CAT(_register_kernel_, cnt) = \ - KernelRegistry::getInstance().registerKernel( \ - KernelAttrs{device, opType, dataType}, new kernel(), name); \ + KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \ + opType}, \ + new kernel(), name); \ } -#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \ - _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__) +#define REGISTER_KERNEL(device, opType, kernel, name) \ + _REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__) #define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \ namespace infini { \ diff --git a/include/core/operator.h b/include/core/operator.h index cc8ce174..220a06c1 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -4,7 +4,7 @@ #include "core/tensor.h" namespace infini { -using KernelAttrs = std::tuple; +using KernelAttrs = std::tuple; struct OpPerfKey { HashType hash; @@ -90,6 +90,7 @@ class OperatorObj : public Object { OpType getOpType() const { return type; } // HACK: set correct data type DataType getDType() const { return getInputs(0)->getDType(); } + DataType getOutDType() const { return getOutput()->getDType(); } virtual int numInputs() const = 0; virtual int numOutputs() const = 0; diff --git a/include/core/tensor.h b/include/core/tensor.h index 95229c14..63efd0f7 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -44,8 +44,16 @@ class TensorObj : public TensorBaseObj { bool isOutput() const { return tensorType == TensorType::output; } bool isOthers() const { return tensorType == TensorType::others; } void setWeight() { tensorType = TensorType::weight; } - void setInput() { tensorType = TensorType::input; } - void setOutput() { tensorType = TensorType::output; } + void setInput() { + if (!this->isWeight()) { + tensorType = TensorType::input; + } + } + void setOutput() { + if (!this->isWeight()) { + tensorType = TensorType::output; + } + } string tensorTypeToString() const { switch (tensorType) { case TensorType::weight: diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h index db9c16f1..10bb1bca 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -1,13 +1,16 @@ #pragma once namespace infini { -void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); -void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); -void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); -void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, - int c3); +void div_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, + int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, + int c2, int c3); +void add_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, + int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, + int c2, int c3); +void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, + int a2, int a3, int b0, int b1, int b2, 
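The registration change above removes the data type from the registry key: KernelAttrs becomes (device, opType) and REGISTER_KERNEL now takes four arguments instead of five, so a single kernel object covers every data type and branches on op->getDType() at run time (the "dataType filter" mentioned in the commit message). A self-contained sketch of that pattern follows; the names Device, OpT, Op and ToyKernel are made up for the sketch and are not the real InfiniTensor classes.

#include <cstdio>
#include <functional>
#include <map>
#include <utility>

enum class Device { CPU, CUDA };
enum class OpT { Relu };
enum class DType { F32, F16 };

struct Op {
    OpT type;
    DType dtype; // stands in for op->getDType()
};

using ToyKernel = std::function<void(const Op &)>;
using Key = std::pair<Device, int>; // (device, opType) only -- no dtype

std::map<Key, ToyKernel> &registry() {
    static std::map<Key, ToyKernel> r;
    return r;
}

bool registerKernel(Device d, OpT t, ToyKernel k) {
    registry()[{d, static_cast<int>(t)}] = std::move(k);
    return true;
}

// One registry entry handles all dtypes and dispatches internally.
static bool reluRegistered = registerKernel(Device::CUDA, OpT::Relu, [](const Op &op) {
    switch (op.dtype) {
    case DType::F32: std::puts("launch relu<float>"); break;
    case DType::F16: std::puts("launch relu<half>"); break;
    }
});

int main() {
    if (!reluRegistered)
        return 1;
    Op op{OpT::Relu, DType::F16};
    registry().at({Device::CUDA, static_cast<int>(op.type)})(op); // prints "launch relu<half>"
}

Under the old scheme the same kernel needed one registry entry per supported DataType; with the new macro a registration call simply drops the DataType argument, as the kernel.h diff above shows.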
int b3, int c0, int c1, + int c2, int c3); +void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, + int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, + int c2, int c3); }; // namespace infini diff --git a/include/cuda/cuda_expand.h b/include/cuda/cuda_expand.h index 8d4701fd..3723a8e7 100644 --- a/include/cuda/cuda_expand.h +++ b/include/cuda/cuda_expand.h @@ -3,7 +3,8 @@ #include "operators/unary.h" #include "utils/small_array.h" namespace infini { -void expandKernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape); +void expandKernel(int dType, void *input, void *output, int nDims, + int outputsize, SmallArray inputShape, + SmallArray outputShape); }; // namespace infini diff --git a/include/cuda/cuda_layernorm.h b/include/cuda/cuda_layernorm.h index 997c8a06..b6829d09 100644 --- a/include/cuda/cuda_layernorm.h +++ b/include/cuda/cuda_layernorm.h @@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps, void LaynormKernel(const float *input, const float *scale, const float eps, int size, int scaleSize, const int dimsize, const int stride, float *output); +void LaynormKernel(const half *input, const half *scale, const half eps, + int size, int scaleSize, const int dimsize, const int stride, + half *output, const half *bias, int biasSize); +void LaynormKernel(const half *input, const half *scale, const half eps, + int size, int scaleSize, const int dimsize, const int stride, + half *output); }; // namespace infini diff --git a/include/cuda/cuda_softmax.h b/include/cuda/cuda_softmax.h index 671f46f8..c0a54a34 100644 --- a/include/cuda/cuda_softmax.h +++ b/include/cuda/cuda_softmax.h @@ -3,4 +3,6 @@ namespace infini { void softmax_kernel(int num_blocks, float *input, float *output, int size, int dimsize, int stride); -} +void softmax_kernel(int num_blocks, half *input, half *output, int size, + int dimsize, int stride); +} // namespace infini diff --git a/include/cuda/cuda_split_concat.h b/include/cuda/cuda_split_concat.h index 58bdf330..d324a3ef 100644 --- a/include/cuda/cuda_split_concat.h +++ b/include/cuda/cuda_split_concat.h @@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8; // Concat operator acts like element tensors composing to one big tensor,and // split operator acts like one big tensor being composed by element // tensors. -struct ElementTensorMetadata { - float *data[BATCH_SIZE]; +template struct ElementTensorMetadata { + T *data[BATCH_SIZE]; int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in // the composed tensor. int dimSize[BATCH_SIZE]; // the dimention size of the element tensor. 
@@ -20,16 +20,17 @@ struct ElementTensorMetadata { data[i], dimBgNo[i], dimSize[i], nElements[i]); } }; - -struct ComposedTensorMetadata { +template struct ComposedTensorMetadata { int dimSize[DIM_MAX_SIZE]; int stride[DIM_MAX_SIZE]; - float *data; + T *data; }; namespace infini { -void split_concat_kernel(const ElementTensorMetadata &eleMeta, - const ComposedTensorMetadata &compMeta, int dim, +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit); +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, int batchSize, int nDims, bool isSplit); - } // namespace infini diff --git a/include/cuda/cuda_transpose.h b/include/cuda/cuda_transpose.h index b168cf0e..89d080ed 100644 --- a/include/cuda/cuda_transpose.h +++ b/include/cuda/cuda_transpose.h @@ -5,7 +5,7 @@ namespace infini { -void transpose_kernel(float *input, float *output, int nDims, int size, +void transpose_kernel(int dType, void *input, void *output, int nDims, int size, SmallArray strides, SmallArray outputShape); }; // namespace infini diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h index 31a39951..49a589b3 100644 --- a/include/cuda/cuda_unary.h +++ b/include/cuda/cuda_unary.h @@ -3,48 +3,21 @@ #include "operators/unary.h" namespace infini { -void softmax_kernel(float *input, float *output, size_t num); -void relu_kernel(float *input, float *output, size_t num); -void sigmoid_kernel(float *input, float *output, size_t num); -void tanh_kernel(float *input, float *output, size_t num); -void abs_kernel(float *input, float *output, size_t num); -void sqrt_kernel(float *input, float *output, size_t num); -void neg_kernel(float *input, float *output, size_t num); -void gelu_kernel(float *input, float *output, size_t num); -void erf_kernel(float *input, float *output, size_t num); -void hard_sigmoid_kernel(float *input, float *output, size_t num); -void hard_swish_kernel(float *input, float *output, size_t num); +template void softmax_kernel(T *input, T *output, size_t num); +template void relu_kernel(T *input, T *output, size_t num); +template void sigmoid_kernel(T *input, T *output, size_t num); +template void tanh_kernel(T *input, T *output, size_t num); +template void abs_kernel(T *input, T *output, size_t num); +template void sqrt_kernel(T *input, T *output, size_t num); +template void neg_kernel(T *input, T *output, size_t num); +template void gelu_kernel(T *input, T *output, size_t num); +template void erf_kernel(T *input, T *output, size_t num); +template void hard_sigmoid_kernel(T *input, T *output, size_t num); +template void hard_swish_kernel(T *input, T *output, size_t num); -void unary_kernel(const Operator &_op) { - auto op = as(_op); - float *const inputData = (op->getInputs(0)->getRawDataPtr()); - float *const outputData = (op->getOutput()->getRawDataPtr()); +template +void cast_kernel(INPUT *input, OUTPUT *output, size_t num); - size_t num = op->getOutput()->size(); - if (op->getOpType() == OpType::Softmax) - softmax_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Relu) - relu_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Sigmoid) - sigmoid_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::HardSigmoid) - hard_sigmoid_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::HardSwish) - hard_swish_kernel(inputData, outputData, num); - else if 
(op->getOpType() == OpType::Tanh) - tanh_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Abs) - abs_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Sqrt) - sqrt_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Gelu) - gelu_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Neg) - neg_kernel(inputData, outputData, num); - else if (op->getOpType() == OpType::Erf) - erf_kernel(inputData, outputData, num); - else - IT_TODO_HALT(); -} +void unary_kernel(const Operator &_op); }; // namespace infini diff --git a/include/cuda/cuda_utility.h b/include/cuda/cuda_utility.h index 85e3478b..bc340912 100644 --- a/include/cuda/cuda_utility.h +++ b/include/cuda/cuda_utility.h @@ -1,11 +1,29 @@ +#pragma once #include "core/tensor.h" +#include "cuda/cuda_common.h" namespace infini { void cudaPrintFloat(float *x, int len); -void cudaPrintTensor(const Tensor &tensor) { - cudaPrintFloat(tensor->getRawDataPtr(), tensor->size()); -} +void cudaPrintTensor(const Tensor &tensor); -} // namespace infini \ No newline at end of file +cudnnDataType_t cudnnDataTypeConvert(DataType dataType); +cudaDataType cublasDataTypeConvert(DataType); + +template struct DT_CUDA {}; +template <> struct DT_CUDA<0> { using t = bool; }; +template <> struct DT_CUDA<1> { using t = float; }; +template <> struct DT_CUDA<2> { using t = unsigned char; }; +template <> struct DT_CUDA<3> { using t = char; }; +template <> struct DT_CUDA<4> { using t = unsigned short; }; +template <> struct DT_CUDA<5> { using t = short; }; +template <> struct DT_CUDA<6> { using t = int; }; +template <> struct DT_CUDA<7> { using t = long long; }; +template <> struct DT_CUDA<9> { using t = bool; }; +template <> struct DT_CUDA<10> { using t = half; }; +template <> struct DT_CUDA<11> { using t = double; }; +template <> struct DT_CUDA<12> { using t = unsigned int; }; +template <> struct DT_CUDA<13> { using t = unsigned long long; }; +template <> struct DT_CUDA<16> { using t = nv_bfloat16; }; +} // namespace infini diff --git a/include/cuda/cuda_where.h b/include/cuda/cuda_where.h index bc6d3e81..8c2ba2db 100644 --- a/include/cuda/cuda_where.h +++ b/include/cuda/cuda_where.h @@ -3,10 +3,15 @@ #include "utils/small_array.h" namespace infini { + void whereKernel(const float *inputX, const float *inputY, const uint8_t *condition, float *output, int nDims, int outputsize, SmallArray inputXShape, SmallArray inputYShape, SmallArray conditionShape, SmallArray outputShape, int xSize, int ySize, int cSize); - +void whereKernel(const half *inputX, const half *inputY, + const uint8_t *condition, half *output, int nDims, + int outputsize, SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape, int xSize, + int ySize, int cSize); }; // namespace infini diff --git a/include/cuda/gather.h b/include/cuda/gather.h index 0f0a1b27..bea716c0 100644 --- a/include/cuda/gather.h +++ b/include/cuda/gather.h @@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData, metaData.inStride[i] = in->getStride()[i]; } } -void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num); +template +void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num); void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, size_t num); diff --git a/include/utils/data_generator.h b/include/utils/data_generator.h index 982db835..970b8038 100644 --- a/include/utils/data_generator.h +++ b/include/utils/data_generator.h 
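The headers above template the unary and gather kernels over the element type and thread a dtype index through the launchers; the DT_CUDA table maps index 1 to float and index 10 to half. A minimal CUDA sketch of that shape is given below. It is illustrative only, not the patch's actual unary.cu; the arithmetic is done in fp32 so one kernel body serves both float and half.

#include <cuda_fp16.h>
#include <cstddef>

template <class T> __global__ void relu_unary(const T *in, T *out, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i < n) {
        // Convert to fp32, apply ReLU, convert back; valid for T = float and T = half.
        float x = static_cast<float>(in[i]);
        out[i] = static_cast<T>(x > 0.f ? x : 0.f);
    }
}

void relu(int dtypeIndex, const void *in, void *out, size_t n) {
    if (n == 0)
        return;
    constexpr int threads = 256;
    int blocks = static_cast<int>((n + threads - 1) / threads);
    if (dtypeIndex == 1) // float32 (index 1 in the DT_CUDA table above)
        relu_unary<float><<<blocks, threads>>>(
            static_cast<const float *>(in), static_cast<float *>(out), n);
    else if (dtypeIndex == 10) // float16 (index 10 in the DT_CUDA table above)
        relu_unary<half><<<blocks, threads>>>(
            static_cast<const half *>(in), static_cast<half *>(out), n);
}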
@@ -91,6 +91,12 @@ template class ValGenerator : public DataGenerator { fill(data, size); } void fill(float *data, size_t size) override { fill(data, size); } + void fill_fp16(uint16_t *data, size_t size) { + for (size_t i = 0; i < size; i++) { + float x = 1.0f * val; + data[i] = float_to_fp16(x); + } + } }; typedef ValGenerator<1> OneGenerator; typedef ValGenerator<0> ZeroGenerator; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 192e5273..79abb7f4 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -37,7 +37,7 @@ class OnnxStub: It can be generated from an Onnx model object. """ - def __init__(self, model: ModelProto, runtime): + def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False): # We use some user-defined operators for distributed inference try: # onnx simplifier performs inplace simplify @@ -51,13 +51,43 @@ class OnnxStub: self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} + self.tensors: Dict[str, backend.Tensor] = {} + self.tensor_node_map: Dict[str, str] = {} self.initializer: Dict[int, TensorProto] = {} + self.use_naive_allocator: bool = use_naive_allocator # try: # model = infer_shapes(model) # except: # warnings.warn("infer_shapes failed.") self.handler = backend.GraphHandler(runtime) + # 处理重名和匿名算子 + names = {} + for node in model.graph.node: + if node.name == "": + node.name = "missing_name(" + node.op_type + ")" + if node.name in names: + names[node.name] += 1 + node.name += "_" + str(names[node.name]) + else: + names[node.name] = 0 + # 拓扑排序 + sorted_nodes = [] + known_edge = set(t.name for t in model.graph.input) + known_edge.update(t.name for t in model.graph.initializer) + while len(sorted_nodes) < len(model.graph.node): + updated = False + for i, node in enumerate(model.graph.node): + if all(t in known_edge for t in node.input): + node.name = str(len(sorted_nodes)) + "_" + node.name + sorted_nodes.append(i) + known_edge.update(node.output) + for t_ in node.output: + self.tensor_node_map[t_] = node.name + updated = True + if not updated: + raise Exception("Graph has cycle") + tensors: Dict[str, backend.Tensor] = dict() data: Dict[str, TensorProto] = dict() @@ -82,98 +112,64 @@ class OnnxStub: ) tensors[output.name].set_output() - node_name = [] - new_node_name = [] - for node in model.graph.node: - node_name.append(node.name) - node_list = model.graph.node - while len(node_list) != 0: - for node in model.graph.node: - if node.name not in node_list: - continue - if _analyse_node(node, tensors): - continue - if node.op_type == "Conv": - attributes = _parse_attribute( - node, - { - "dilations": [1, 1], - "pads": [0, 0, 0, 0], - "strides": [1, 1], - }, + for node_idx in sorted_nodes: + node = model.graph.node[node_idx] + if node.op_type == "Conv": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0, 0, 0], + "strides": [1, 1], + }, + ) + (d, p, s) = ( + attributes[name] for name in ["dilations", "pads", "strides"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors[node.input[0]], None, p, [-2, -1] ) - (d, p, s) = ( - attributes[name] for name in ["dilations", "pads", "strides"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors[node.input[0]], None, p, [-2, -1] - ) - p = [0, 0, 0, 0] - else: - adapt = node.input[0] 
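The ValGenerator::fill_fp16 helper added above calls a float_to_fp16 routine that this hunk does not show. For reference, a truncating float-to-binary16 conversion can be written as below. This is a sketch only: it flushes subnormals to zero, truncates instead of rounding to nearest-even, and maps NaN to infinity, so the repository's helper may well differ.

#include <cstdint>
#include <cstring>

uint16_t float_to_fp16_sketch(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    uint16_t sign = (bits >> 16) & 0x8000u;                    // sign bit moved to bit 15
    int32_t exp = (int32_t)((bits >> 23) & 0xFF) - 127 + 15;   // rebias exponent (127 -> 15)
    uint32_t mant = bits & 0x007FFFFFu;                        // 23-bit mantissa
    if (exp <= 0)
        return sign;                                           // too small: flush to signed zero
    if (exp >= 0x1F)
        return sign | 0x7C00u;                                 // overflow / inf / NaN -> inf
    return sign | (uint16_t)(exp << 10) | (uint16_t)(mant >> 13); // keep top 10 mantissa bits
}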
+ p = [0, 0, 0, 0] + else: + adapt = node.input[0] - if len(node.input) > 2: - bias = "{}-bias".format(node.output[0]) - reshape = "{}-reshape".format(node.output[0]) - tensors[bias] = self.handler.conv( - tensors[adapt], - tensors[node.input[1]], - None, - p[0], - p[1], - s[0], - s[1], - d[0], - d[1], - ) - tensors[reshape] = self.handler.reshape( - tensors[node.input[2]], - None, - [ - 1, - reduce( - lambda acc, x: acc * x, - tensors[node.input[2]].shape(), - ), - 1, - 1, - ], - ) - tensors[node.output[0]] = self.handler.add( - tensors[bias], - tensors[reshape], - tensors.get(node.output[0]), - ) - else: - tensors[node.output[0]] = self.handler.conv( - tensors[adapt], - tensors[node.input[1]], - tensors.get(node.output[0]), - p[0], - p[1], - s[0], - s[1], - d[0], - d[1], - ) - elif node.op_type == "ConvTranspose": - attributes = _parse_attribute( - node, - { - "dilations": [1, 1], - "pads": [0, 0], - "strides": [1, 1], - "output_padding": [0, 0], - }, + if len(node.input) > 2: + bias = "{}-bias".format(node.output[0]) + reshape = "{}-reshape".format(node.output[0]) + tensors[bias] = self.handler.conv( + tensors[adapt], + tensors[node.input[1]], + None, + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], ) - (d, p, s, op) = ( - attributes[name] - for name in ["dilations", "pads", "strides", "output_padding"] + tensors[reshape] = self.handler.reshape( + tensors[node.input[2]], + None, + [ + 1, + reduce( + lambda acc, x: acc * x, + tensors[node.input[2]].shape(), + ), + 1, + 1, + ], ) - tensors[node.output[0]] = self.handler.convTransposed2d( - tensors[node.input[0]], + tensors[node.output[0]] = self.handler.add( + tensors[bias], + tensors[reshape], + tensors.get(node.output[0]), + ) + else: + tensors[node.output[0]] = self.handler.conv( + tensors[adapt], tensors[node.input[1]], tensors.get(node.output[0]), p[0], @@ -182,540 +178,632 @@ class OnnxStub: s[1], d[0], d[1], - op[0], - op[1], ) - elif node.op_type == "MatMul": - tensors[node.output[0]] = self.handler.matmul( - tensors[node.input[0]], - tensors[node.input[1]], + elif node.op_type == "ConvTranspose": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0], + "strides": [1, 1], + "output_padding": [0, 0], + }, + ) + (d, p, s, op) = ( + attributes[name] + for name in ["dilations", "pads", "strides", "output_padding"] + ) + tensors[node.output[0]] = self.handler.convTransposed2d( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + op[0], + op[1], + ) + elif node.op_type == "MatMul": + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + False, + False, + None, + backend.ActType.Linear, + ) + elif node.op_type == "Gemm": + attributes = _parse_attribute( + node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} + ) + (alpha, beta, transA, transB) = ( + attributes[name] for name in ["alpha", "beta", "transA", "transB"] + ) + # FIXME unsupport attributes: `alpha` `beta` + assert alpha == 1.0 + assert beta == 1.0 + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + transA == 1, + transB == 1, + tensors[node.input[2]] if len(node.input) > 2 else None, + backend.ActType.Linear, + ) + elif node.op_type == "BatchNormalization": + (input, mean, var, scale, bias) = ( + tensors[node.input[i]] for i in [0, 3, 4, 1, 2] + ) + output = tensors.get(node.output[0]) + attributes = 
_parse_attribute( + node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} + ) + (momentum, eps, training) = ( + attributes[name] + for name in ["momentum", "epsilon", "training_mode"] + ) + tensors[node.output[0]] = self.handler.batchNormalization( + input, + output, + mean, + var, + scale, + bias, + momentum, + eps, + training != 0, + ) + elif node.op_type == "LayerNormalization": + (input, scale) = (tensors[node.input[i]] for i in [0, 1]) + bias = None if len(node.input) < 3 else tensors[node.input[2]] + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1} + ) + (axis, eps, stash_type) = ( + attributes[name] for name in ["axis", "epsilon", "stash_type"] + ) + tensors[node.output[0]] = self.handler.layerNormalization( + input, + scale, + output, + bias, + eps, + axis, + stash_type, + ) + elif node.op_type == "MaxPool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "dilations": [1, 1], + "pads": [0, 0, 0, 0], + "strides": [1, 1], + "ceil_mode": 0, + }, + ) + (k, d, p, s, ceil_mode) = ( + attributes[name] + for name in [ + "kernel_shape", + "dilations", + "pads", + "strides", + "ceil_mode", + ] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] + ) + tensors[node.output[0]] = self.handler.maxPool( + tensors[adapt], tensors.get(node.output[0]), - False, - False, - None, - backend.ActType.Linear, + k[0], + k[1], + d[0], + d[1], + 0, + 0, + s[0], + s[1], + ceil_mode, ) - elif node.op_type == "Gemm": - attributes = _parse_attribute( - node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} - ) - (alpha, beta, transA, transB) = ( - attributes[name] - for name in ["alpha", "beta", "transA", "transB"] - ) - # FIXME unsupport attributes: `alpha` `beta` - assert alpha == 1.0 - assert beta == 1.0 - tensors[node.output[0]] = self.handler.matmul( + else: + tensors[node.output[0]] = self.handler.maxPool( tensors[node.input[0]], - tensors[node.input[1]], tensors.get(node.output[0]), - transA == 1, - transB == 1, - tensors[node.input[2]] if len(node.input) > 2 else None, - backend.ActType.Linear, + k[0], + k[1], + d[0], + d[1], + p[0], + p[1], + s[0], + s[1], + ceil_mode, ) - elif node.op_type == "BatchNormalization": - (input, mean, var, scale, bias) = ( - tensors[node.input[i]] for i in [0, 3, 4, 1, 2] + elif node.op_type == "AveragePool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "pads": [0, 0, 0, 0], + "strides": [1, 1], + "ceil_mode": 0, + }, + ) + (k, p, s, ceil_mode) = ( + attributes[name] + for name in ["kernel_shape", "pads", "strides", "ceil_mode"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] ) - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} + tensors[node.output[0]] = self.handler.avgPool( + tensors[adapt], + tensors.get(node.output[0]), + k[0], + k[1], + 1, + 1, + 0, + 0, + s[0], + s[1], + ceil_mode, ) - (momentum, eps, training) = ( - attributes[name] - for name in ["momentum", "epsilon", "training_mode"] - ) - tensors[node.output[0]] = self.handler.batchNormalization( - input, - output, - mean, - var, - scale, - bias, - momentum, - eps, - training != 0, - ) - elif node.op_type == "LayerNormalization": - (input, scale) = 
(tensors[node.input[i]] for i in [0, 1]) - bias = None if len(node.input) < 3 else tensors[node.input[2]] - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1} - ) - (axis, eps, stash_type) = ( - attributes[name] for name in ["axis", "epsilon", "stash_type"] - ) - tensors[node.output[0]] = self.handler.layerNormalization( - input, - scale, - output, - bias, - eps, - axis, - stash_type, - ) - elif node.op_type == "MaxPool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "dilations": [1, 1], - "pads": [0, 0, 0, 0], - "strides": [1, 1], - "ceil_mode": 0, - }, - ) - (k, d, p, s, ceil_mode) = ( - attributes[name] - for name in [ - "kernel_shape", - "dilations", - "pads", - "strides", - "ceil_mode", - ] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.maxPool( - tensors[adapt], - tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - 0, - 0, - s[0], - s[1], - ceil_mode, - ) - else: - tensors[node.output[0]] = self.handler.maxPool( - tensors[node.input[0]], - tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - p[0], - p[1], - s[0], - s[1], - ceil_mode, - ) - elif node.op_type == "AveragePool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "pads": [0, 0, 0, 0], - "strides": [1, 1], - "ceil_mode": 0, - }, - ) - (k, p, s, ceil_mode) = ( - attributes[name] - for name in ["kernel_shape", "pads", "strides", "ceil_mode"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.avgPool( - tensors[adapt], - tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - 0, - 0, - s[0], - s[1], - ceil_mode, - ) - else: - tensors[node.output[0]] = self.handler.avgPool( - tensors[node.input[0]], - tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - p[0], - p[1], - s[0], - s[1], - ceil_mode, - ) - elif node.op_type == "GlobalAveragePool": - [_, _, h, w] = tensors[node.input[0]].shape() + else: tensors[node.output[0]] = self.handler.avgPool( tensors[node.input[0]], tensors.get(node.output[0]), - h, - w, + k[0], + k[1], 1, 1, - 0, - 0, + p[0], + p[1], + s[0], + s[1], + ceil_mode, + ) + elif node.op_type == "GlobalAveragePool": + [_, _, h, w] = tensors[node.input[0]].shape() + tensors[node.output[0]] = self.handler.avgPool( + tensors[node.input[0]], + tensors.get(node.output[0]), + h, + w, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + ) + elif node.op_type == "Add": + tensors[node.output[0]] = self.handler.add( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sub": + tensors[node.output[0]] = self.handler.sub( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Mul": + tensors[node.output[0]] = self.handler.mul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Div": + tensors[node.output[0]] = self.handler.div( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Pow": + tensors[node.output[0]] = self.handler.pow( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Min": + 
tensors[node.output[0]] = self.handler.min( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Max": + tensors[node.output[0]] = self.handler.max( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Relu": + tensors[node.output[0]] = self.handler.relu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Gelu": + tensors[node.output[0]] = self.handler.gelu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sigmoid": + tensors[node.output[0]] = self.handler.sigmoid( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "HardSigmoid": + tensors[node.output[0]] = self.handler.hardSigmoid( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "HardSwish": + tensors[node.output[0]] = self.handler.hardSwish( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Tanh": + tensors[node.output[0]] = self.handler.tanh( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Softmax": + tensors[node.output[0]] = self.handler.softmax( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + -1, + ), + ) + elif node.op_type == "Abs": + tensors[node.output[0]] = self.handler.abs( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sqrt": + tensors[node.output[0]] = self.handler.sqrt( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Neg": + tensors[node.output[0]] = self.handler.neg( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Shape": + tensors[node.output[0]] = self.handler.shape( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Identity": + tensors[node.output[0]] = self.handler.identity( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Flatten": + tensors[node.output[0]] = self.handler.flatten( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), 1, - 1, - 0, - ) - elif node.op_type == "Add": - tensors[node.output[0]] = self.handler.add( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sub": - tensors[node.output[0]] = self.handler.sub( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Mul": - tensors[node.output[0]] = self.handler.mul( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Div": - tensors[node.output[0]] = self.handler.div( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Pow": - tensors[node.output[0]] = self.handler.pow( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Min": - tensors[node.output[0]] = self.handler.min( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Max": - tensors[node.output[0]] = self.handler.max( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Relu": - tensors[node.output[0]] = self.handler.relu( - 
tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Gelu": - tensors[node.output[0]] = self.handler.gelu( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sigmoid": - tensors[node.output[0]] = self.handler.sigmoid( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "HardSigmoid": - tensors[node.output[0]] = self.handler.hardSigmoid( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "HardSwish": - tensors[node.output[0]] = self.handler.hardSwish( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Tanh": - tensors[node.output[0]] = self.handler.tanh( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Softmax": - tensors[node.output[0]] = self.handler.softmax( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - -1, + ), + ) + elif node.op_type == "PRelu": + tensors[node.output[0]] = self.handler.pRelu( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Clip": + tensors[node.output[0]] = self.handler.clip( + tensors[node.input[0]], + tensors.get(node.output[0]), + next(_parse_data(data[node.input[1]]).__iter__(), None) + if len(node.input) > 1 + else None, + next(_parse_data(data[node.input[2]]).__iter__(), None) + if len(node.input) > 2 + else None, + ) + elif node.op_type == "Transpose": + perm = next( + (attr.ints for attr in node.attribute if attr.name == "perm"), + None, + ) + tensors[node.output[0]] = self.handler.transpose( + tensors[node.input[0]], + tensors.get(node.output[0]), + perm, + ) + elif node.op_type == "DepthToSpace": + blocksize = next( + (attr.i for attr in node.attribute if attr.name == "blocksize"), + None, + ) + mode = next( + (attr.s for attr in node.attribute if attr.name == "mode"), + None, + ) + tensors[node.output[0]] = self.handler.depthToSpace( + tensors[node.input[0]], + tensors.get(node.output[0]), + blocksize, + mode, + ) + elif node.op_type == "Reshape": + shape = _parse_data(data[node.input[1]]) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + shape, + ) + elif node.op_type == "Resize": + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, + { + "antialias": 0, + "axes": None, + "coordinate_transformation_mode": "half_pixel", + "cubic_coeff_a": -0.75, + "exclude_outside": 0, + "extrapolation_value": 0.0, + "keep_aspect_ratio_policy": "none", + "mode": "nearest", + "nearest_mode": "none", + }, + ) + ( + axes, + keep_aspect_ratio_policy, + coordinate_transformation_mode, + mode, + nearest_mode, + ) = ( + attributes[name] + for name in [ + "axes", + "keep_aspect_ratio_policy", + "coordinate_transformation_mode", + "mode", + "nearest_mode", + ] + ) + if len(node.input) > 1: + roiVal = _parse_data(data[node.input[1]]) + else: + roiVal = [] + if len(node.input) > 2: + scalesVal = _parse_data(data[node.input[2]]) + else: + scalesVal = [] + if len(node.input) > 3: + sizesVal = _parse_data(data[node.input[3]]) + else: + sizesVal = [] + tensors[node.output[0]] = self.handler.resize( + tensors[node.input[0]], + output, + axes, + tensors[node.input[3]] if len(node.input) > 3 else None, + tensors[node.input[2]] if len(node.input) > 2 else None, + tensors[node.input[1]] if len(node.input) > 1 else None, + sizesVal, + scalesVal, + roiVal, + mode, + 
keep_aspect_ratio_policy, + nearest_mode, + coordinate_transformation_mode, + ) + elif node.op_type == "Squeeze": + axes = ( + _parse_data(data[node.input[1]]) + if len(node.input) > 1 + else None + ) + if axes is None: + axes = next( + ( + attr.ints + for attr in node.attribute + if attr.name == "axes" ), + [], ) - elif node.op_type == "Abs": - tensors[node.output[0]] = self.handler.abs( - tensors[node.input[0]], - tensors.get(node.output[0]), + tensors[node.output[0]] = self.handler.squeeze( + tensors[node.input[0]], + tensors.get(node.output[0]), + axes, + ) + elif node.op_type == "Unsqueeze": + axes = ( + _parse_data(data[node.input[1]]) + if len(node.input) > 1 + else None + ) + if axes is None: + axes = next( + ( + attr.ints + for attr in node.attribute + if attr.name == "axes" + ) ) - elif node.op_type == "Sqrt": - tensors[node.output[0]] = self.handler.sqrt( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Neg": - tensors[node.output[0]] = self.handler.neg( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Shape": - tensors[node.output[0]] = self.handler.shape( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Identity": - tensors[node.output[0]] = self.handler.identity( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Flatten": - tensors[node.output[0]] = self.handler.flatten( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 1, + tensors[node.output[0]] = self.handler.unsqueeze( + tensors[node.input[0]], + tensors.get(node.output[0]), + axes, + ) + elif node.op_type == "Concat": + tensors[node.output[0]] = self.handler.concat( + [tensors[name] for name in node.input], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), + ) + elif node.op_type == "AttentionKVCache": + tensors[node.output[0]] = self.handler.attentionKVCache( + tensors[node.input[0]], + tensors[node.input[1]], + tensors[node.input[2]], + tensors[node.input[3]], + tensors[node.input[4]], + tensors[node.input[5]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Split": + split = ( + _parse_data(data[node.input[1]]) + if (len(node.input) > 1) + else None + ) + if split is None: + split = next( + ( + attr.ints + for attr in node.attribute + if attr.name == "split" ), + None, ) - elif node.op_type == "PRelu": - tensors[node.output[0]] = self.handler.pRelu( + for name, tensor in zip( + node.output, + self.handler.split( tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), + None, + next( + ( + attr.i + for attr in node.attribute + if attr.name == "axis" + ), + 0, + ), + split if split is not None else len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Gather": + tensors[node.output[0]] = self.handler.gather( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ), + ) + elif node.op_type == "GatherElements": + tensors[node.output[0]] = self.handler.gatherElements( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ), + ) + elif node.op_type == "ReduceMean": + tensors[node.output[0]] = self.handler.reduceMean( + tensors[node.input[0]], + tensors.get(node.output[0]), + # NOTE(constroy): 
`axes` is an attribute until opset version 13. + next( + (attr.ints for attr in node.attribute if attr.name == "axes"), + None, + ), + next( + (attr.i for attr in node.attribute if attr.name == "keepdims"), + 1, ) - elif node.op_type == "Clip": - tensors[node.output[0]] = self.handler.clip( + != 0, + ) + elif node.op_type == "Slice": + + def clamp(nums): + MAX_INT = 0x7FFFFFFF + return [min(x, MAX_INT) for x in nums] + + tensors[node.output[0]] = self.handler.slice( + tensors[node.input[0]], + tensors.get(node.output[0]), + clamp(_parse_data(data[node.input[1]])), + clamp(_parse_data(data[node.input[2]])), + clamp(_parse_data(data[node.input[3]])) + if len(node.input) > 3 + else None, + clamp(_parse_data(data[node.input[4]])) + if len(node.input) > 4 + else None, + ) + elif node.op_type == "Pad": + tensors[node.output[0]] = self.handler.pad( + tensors[node.input[0]], + tensors.get(node.output[0]), + _parse_data(data[node.input[1]]), + _parse_data(data[node.input[3]]) if len(node.input) > 3 else None, + ) + elif node.op_type == "Dropout": + for name, tensor in zip( + node.output, + self.handler.dropout( tensors[node.input[0]], tensors.get(node.output[0]), - next(_parse_data(data[node.input[1]]).__iter__(), None) + tensors.get(node.output[1]) if len(node.output) > 1 else None, + _parse_data(data[node.input[1]])[0] if len(node.input) > 1 - else None, - next(_parse_data(data[node.input[2]]).__iter__(), None) + else 0.5, + _parse_data(data[node.input[2]])[0] if len(node.input) > 2 - else None, - ) - elif node.op_type == "Transpose": - perm = next( - (attr.ints for attr in node.attribute if attr.name == "perm"), - None, - ) - tensors[node.output[0]] = self.handler.transpose( + else False, + ), + ): + tensors[name] = tensor + elif node.op_type == "Cast": + tensors[node.output[0]] = self.handler.cast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "to")), + ) + elif node.op_type == "ReduceSum": + if any(attr.name == "communicator" for attr in node.attribute): + # ReduceSum with communicator is treated as allReduceSum. + tensors[node.output[0]] = self.handler.allReduceSum( tensors[node.input[0]], tensors.get(node.output[0]), - perm, - ) - elif node.op_type == "DepthToSpace": - blocksize = next( - (attr.i for attr in node.attribute if attr.name == "blocksize"), - None, - ) - mode = next( - (attr.s for attr in node.attribute if attr.name == "mode"), - None, - ) - tensors[node.output[0]] = self.handler.depthToSpace( - tensors[node.input[0]], - tensors.get(node.output[0]), - blocksize, - mode, - ) - elif node.op_type == "Reshape": - shape = _parse_data(data[node.input[1]]) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - shape, - ) - elif node.op_type == "Resize": - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, - { - "antialias": 0, - "axes": None, - "coordinate_transformation_mode": "half_pixel", - "cubic_coeff_a": -0.75, - "exclude_outside": 0, - "extrapolation_value": 0.0, - "keep_aspect_ratio_policy": "none", - "mode": "nearest", - "nearest_mode": "none", - }, - ) - ( - axes, - keep_aspect_ratio_policy, - coordinate_transformation_mode, - mode, - nearest_mode, - ) = ( - attributes[name] - for name in [ - "axes", - "keep_aspect_ratio_policy", - "coordinate_transformation_mode", - "mode", - "nearest_mode", - ] ) + else: + # NOTE: `axes` is an attribute until opset version 13. 
if len(node.input) > 1: - roiVal = _parse_data(data[node.input[1]]) + axis = _parse_data(data[node.input[1]]) else: - roiVal = [] - if len(node.input) > 2: - scalesVal = _parse_data(data[node.input[2]]) - else: - scalesVal = [] - if len(node.input) > 3: - sizesVal = _parse_data(data[node.input[3]]) - else: - sizesVal = [] - tensors[node.output[0]] = self.handler.resize( - tensors[node.input[0]], - output, - axes, - tensors[node.input[3]] if len(node.input) > 3 else None, - tensors[node.input[2]] if len(node.input) > 2 else None, - tensors[node.input[1]] if len(node.input) > 1 else None, - sizesVal, - scalesVal, - roiVal, - mode, - keep_aspect_ratio_policy, - nearest_mode, - coordinate_transformation_mode, - ) - elif node.op_type == "Squeeze": - axes = ( - _parse_data(data[node.input[1]]) - if len(node.input) > 1 - else None - ) - if axes is None: - axes = next( - ( - attr.ints - for attr in node.attribute - if attr.name == "axes" - ), - [], - ) - tensors[node.output[0]] = self.handler.squeeze( - tensors[node.input[0]], - tensors.get(node.output[0]), - axes, - ) - elif node.op_type == "Unsqueeze": - axes = ( - _parse_data(data[node.input[1]]) - if len(node.input) > 1 - else None - ) - if axes is None: - axes = next( - ( - attr.ints - for attr in node.attribute - if attr.name == "axes" - ) - ) - tensors[node.output[0]] = self.handler.unsqueeze( - tensors[node.input[0]], - tensors.get(node.output[0]), - axes, - ) - elif node.op_type == "Concat": - tensors[node.output[0]] = self.handler.concat( - [tensors[name] for name in node.input], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis") - ), - ) - elif node.op_type == "AttentionKVCache": - tensors[node.output[0]] = self.handler.attentionKVCache( - tensors[node.input[0]], - tensors[node.input[1]], - tensors[node.input[2]], - tensors[node.input[3]], - tensors[node.input[4]], - tensors[node.input[5]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Split": - split = ( - _parse_data(data[node.input[1]]) - if (len(node.input) > 1) - else None - ) - if split is None: - split = next( - ( - attr.ints - for attr in node.attribute - if attr.name == "split" - ), - None, - ) - for name, tensor in zip( - node.output, - self.handler.split( - tensors[node.input[0]], - None, - next( - ( - attr.i - for attr in node.attribute - if attr.name == "axis" - ), - 0, - ), - split if split is not None else len(node.output), - ), - ): - tensors[name] = tensor - elif node.op_type == "Gather": - tensors[node.output[0]] = self.handler.gather( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 0, - ), - ) - elif node.op_type == "GatherElements": - tensors[node.output[0]] = self.handler.gatherElements( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 0, - ), - ) - elif node.op_type == "ReduceMean": - tensors[node.output[0]] = self.handler.reduceMean( - tensors[node.input[0]], - tensors.get(node.output[0]), - # NOTE(constroy): `axes` is an attribute until opset version 13. 
- next( + axis = next( ( attr.ints for attr in node.attribute if attr.name == "axes" ), None, - ), + ) + keepdims = ( next( ( attr.i @@ -724,245 +812,153 @@ class OnnxStub: ), 1, ) - != 0, - ) - elif node.op_type == "Slice": - - def clamp(nums): - MAX_INT = 0x7FFFFFFF - return [min(x, MAX_INT) for x in nums] - - tensors[node.output[0]] = self.handler.slice( - tensors[node.input[0]], - tensors.get(node.output[0]), - clamp(_parse_data(data[node.input[1]])), - clamp(_parse_data(data[node.input[2]])), - clamp(_parse_data(data[node.input[3]])) - if len(node.input) > 3 - else None, - clamp(_parse_data(data[node.input[4]])) - if len(node.input) > 4 - else None, - ) - elif node.op_type == "Pad": - tensors[node.output[0]] = self.handler.pad( - tensors[node.input[0]], - tensors.get(node.output[0]), - _parse_data(data[node.input[1]]), - _parse_data(data[node.input[3]]) - if len(node.input) > 3 - else None, - ) - elif node.op_type == "Dropout": - for name, tensor in zip( - node.output, - self.handler.dropout( - tensors[node.input[0]], - tensors.get(node.output[0]), - tensors.get(node.output[1]) - if len(node.output) > 1 - else None, - _parse_data(data[node.input[1]])[0] - if len(node.input) > 1 - else 0.5, - _parse_data(data[node.input[2]])[0] - if len(node.input) > 2 - else False, - ), - ): - tensors[name] = tensor - elif node.op_type == "Cast": - tensors[node.output[0]] = self.handler.cast( - tensors[node.input[0]], - tensors.get(node.output[0]), - next((attr.i for attr in node.attribute if attr.name == "to")), - ) - elif node.op_type == "ReduceSum": - if any(attr.name == "communicator" for attr in node.attribute): - # ReduceSum with communicator is treated as allReduceSum. - tensors[node.output[0]] = self.handler.allReduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - else: - # NOTE: `axes` is an attribute until opset version 13. 
- if len(node.input) > 1: - axis = _parse_data(data[node.input[1]]) - else: - axis = next( - ( - attr.ints - for attr in node.attribute - if attr.name == "axes" - ), - None, - ) - keepdims = ( - next( - ( - attr.i - for attr in node.attribute - if attr.name == "keepdims" - ), - 1, - ) - != 0 - ) - - tensors[node.output[0]] = self.handler.reduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - axis, - keepdims, - ) - elif node.op_type == "AllReduceSum": - tensors[node.output[0]] = self.handler.allReduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceProd": - tensors[node.output[0]] = self.handler.allReduceProd( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceMin": - tensors[node.output[0]] = self.handler.allReduceMin( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceMax": - tensors[node.output[0]] = self.handler.allReduceMax( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceAvg": - tensors[node.output[0]] = self.handler.allReduceAvg( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllGather": - for name, tensor in zip( - node.output, - self.handler.allGather( - tensors[node.input[0]], - None, - len(node.output), - ), - ): - tensors[name] = tensor - elif node.op_type == "Broadcast": - tensors[node.output[0]] = self.handler.broadcast( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "root"), - 0, - ), - ) - elif node.op_type == "Send": - source = next( - (attr.i for attr in node.attribute if attr.name == "source"), - 0, - ) - destination = next( - ( - attr.i - for attr in node.attribute - if attr.name == "destination" - ), - 0, + != 0 ) - self.handler.send( + tensors[node.output[0]] = self.handler.reduceSum( + tensors[node.input[0]], + tensors.get(node.output[0]), + axis, + keepdims, + ) + elif node.op_type == "AllReduceSum": + tensors[node.output[0]] = self.handler.allReduceSum( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceProd": + tensors[node.output[0]] = self.handler.allReduceProd( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMin": + tensors[node.output[0]] = self.handler.allReduceMin( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMax": + tensors[node.output[0]] = self.handler.allReduceMax( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceAvg": + tensors[node.output[0]] = self.handler.allReduceAvg( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllGather": + for name, tensor in zip( + node.output, + self.handler.allGather( tensors[node.input[0]], - source, - destination, None, - ) - elif node.op_type == "Recv": - source = next( - (attr.i for attr in node.attribute if attr.name == "source"), + len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Broadcast": + tensors[node.output[0]] = self.handler.broadcast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "root"), 0, - ) - destination = next( - ( - attr.i - for attr in node.attribute - if attr.name == "destination" - ), - 0, - ) + ), + ) + elif node.op_type == "Send": + source = next( + (attr.i for attr in node.attribute 
if attr.name == "source"), + 0, + ) + destination = next( + (attr.i for attr in node.attribute if attr.name == "destination"), + 0, + ) - for attr in node.attribute: - if attr.name == "shape": - shapeBasic = attr.ints - shape = [] - for item in shapeBasic: - shape.append(item) + self.handler.send( + tensors[node.input[0]], + source, + destination, + None, + ) + elif node.op_type == "Recv": + source = next( + (attr.i for attr in node.attribute if attr.name == "source"), + 0, + ) + destination = next( + (attr.i for attr in node.attribute if attr.name == "destination"), + 0, + ) - for attr in node.attribute: - if attr.name == "dataType": - outputType = attr.i - tensors[node.output[0]] = self.handler.recv( - tensors.get(node.output[0]), - source, - destination, - shape, - outputType, - None, - ) - elif node.op_type == "Expand": - shape = _parse_data(data[node.input[1]]) - tensors[node.output[0]] = self.handler.expand( - tensors[node.input[0]], - tensors.get(node.output[0]), - shape, - ) - elif node.op_type == "Erf": - tensors[node.output[0]] = self.handler.erf( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Where": - tensors[node.output[0]] = self.handler.where( - tensors[node.input[1]], - tensors[node.input[2]], - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Constant": - output_name = node.output[0] - attributes = _parse_attribute(node) - tensor = attributes["value"] - dims = [d for d in tensor.dims] - tensors[output_name] = self.handler.tensor(dims, tensor.data_type) - data[output_name] = tensor - tensors[output_name].set_weight() - elif node.op_type == "LRN": - attributes = _parse_attribute( - node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1} - ) - (alpha, beta, bias, size) = ( - attributes[name] for name in ["alpha", "beta", "bias", "size"] - ) - tensors[node.output[0]] = self.handler.lrn( - tensors[node.input[0]], - tensors.get(node.output[0]), - alpha, - beta, - bias, - size, - ) - else: - raise Exception('Unsupported operator "{}"'.format(node.op_type)) - new_node_name.append(node.name) - # update the node_list - node_list = list(set(node_name) - set(new_node_name)) + for attr in node.attribute: + if attr.name == "shape": + shapeBasic = attr.ints + shape = [] + for item in shapeBasic: + shape.append(item) + + for attr in node.attribute: + if attr.name == "dataType": + outputType = attr.i + tensors[node.output[0]] = self.handler.recv( + tensors.get(node.output[0]), + source, + destination, + shape, + outputType, + None, + ) + elif node.op_type == "Expand": + shape = _parse_data(data[node.input[1]]) + tensors[node.output[0]] = self.handler.expand( + tensors[node.input[0]], + tensors.get(node.output[0]), + shape, + ) + elif node.op_type == "Erf": + tensors[node.output[0]] = self.handler.erf( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Where": + tensors[node.output[0]] = self.handler.where( + tensors[node.input[1]], + tensors[node.input[2]], + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Constant": + output_name = node.output[0] + attributes = _parse_attribute(node) + tensor = attributes["value"] + dims = [d for d in tensor.dims] + tensors[output_name] = self.handler.tensor(dims, tensor.data_type) + data[output_name] = tensor + tensors[output_name].set_weight() + elif node.op_type == "LRN": + attributes = _parse_attribute( + node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1} + ) + (alpha, beta, bias, size) = ( + 
attributes[name] + for name in ["alpha", "beta", "bias", "size"] + ) + tensors[node.output[0]] = self.handler.lrn( + tensors[node.input[0]], + tensors.get(node.output[0]), + alpha, + beta, + bias, + size, + ) + else: + raise Exception('Unsupported operator "{}"'.format(node.op_type)) ################################ # Allocate memory space for data ################################ - self.handler.data_malloc() + self.handler.data_malloc(self.use_naive_allocator) ################################# # Copy in data to tensor objects @@ -993,6 +989,9 @@ class OnnxStub: # assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) obj.copyin_numpy(to_array(tensor)) + for name, obj in tensors.items(): + self.tensors[name] = obj + for output in model.graph.output: self.outputs[output.name] = tensors[output.name] @@ -1237,7 +1236,7 @@ class OnnxStub: axes, ) ) - ctx.push_node(make_node(ty.name, inputs, outputs, name)) + ctx.push_node(make_node(ty.name, inputs, outputs, name)) elif ty == backend.OpTypeId.Concat: axis = backend.concat_axis_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) @@ -1335,7 +1334,7 @@ class OnnxStub: return ctx.build(name) def init(self) -> None: - self.handler.data_malloc() + self.handler.data_malloc(self.use_naive_allocator) def optimize(self) -> None: self.handler.optimize() @@ -1351,7 +1350,7 @@ class OnnxStub: oldTensor = self.inputs[oldInput] self.handler.change_shape(newInput, oldTensor.fuid()) self.handler.shape_infer() - self.handler.data_malloc() + self.handler.data_malloc(self.use_naive_allocator) def getShape(self, name: str) -> List[int]: if name in self.inputs: @@ -1414,10 +1413,3 @@ def _parse_data_fp16(tensor: TensorProto): def _take_shape_dim(shape: TensorShapeProto) -> List[int]: return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim] - - -def _analyse_node(node: NodeProto, tensors) -> bool: - for i in node.input: - if i not in tensors: - return True - return False diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index 2f16b500..c4b426d2 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -16,8 +16,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/core/graph.cc b/src/core/graph.cc index 5eb67402..ac90344a 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -87,48 +87,33 @@ string GraphObj::toString() const { } bool GraphObj::topo_sort() { - if (this->sorted) + if (this->sorted) { return true; - - // std::unordered_set inputs; - std::unordered_set waiting(this->ops.begin(), this->ops.end()); + } std::vector sorted; - - while (!waiting.empty()) { + std::unordered_set flags; + sorted.reserve(ops.size()); + flags.reserve(ops.size()); + while (sorted.size() < ops.size()) { // Any node is move to sorted in this loop. auto modified = false; - // Find head nodes. - for (auto it = waiting.begin(); it != waiting.end();) { - const auto &this_inputs = (*it)->getInputs(); - // If none of the input tensors is in waiting list, - // this node is a head node. 
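
For reference, a minimal standalone sketch of the "insert a node once all of its producers are placed" topological sort that this graph.cc hunk switches to. ToyOp, preds, and topoSort are illustrative names only, not InfiniTensor classes; the real code works on Operator/Tensor objects and source pointers, so this is a sketch of the idea, not the project's implementation.

    // Toy graph type: each op lists the ops that produce its inputs.
    #include <algorithm>
    #include <unordered_set>
    #include <utility>
    #include <vector>

    struct ToyOp {
        std::vector<ToyOp *> preds; // producers of this op's inputs
    };

    // Writes a topological order back into `ops`, or returns false if no
    // node became ready during a full pass (i.e. the graph has a cycle).
    bool topoSort(std::vector<ToyOp *> &ops) {
        std::vector<ToyOp *> sorted;
        std::unordered_set<ToyOp *> done;
        sorted.reserve(ops.size());
        while (sorted.size() < ops.size()) {
            bool modified = false;
            for (auto *op : ops) {
                if (done.count(op))
                    continue;
                bool ready = std::all_of(
                    op->preds.begin(), op->preds.end(),
                    [&](ToyOp *p) { return done.count(p) != 0; });
                if (ready) {
                    sorted.push_back(op);
                    done.insert(op);
                    modified = true;
                }
            }
            if (!modified)
                return false; // nothing became ready: cycle detected
        }
        ops = std::move(sorted);
        return true;
    }

    int main() {
        ToyOp a, b;                          // b consumes a's output
        b.preds = {&a};
        std::vector<ToyOp *> ops = {&b, &a}; // deliberately out of order
        return (topoSort(ops) && ops.front() == &a) ? 0 : 1;
    }
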
- const auto is_head = std::all_of( - this_inputs.begin(), this_inputs.end(), [&](const auto &input) { - auto src = input->getSource(); - return src // If the source node is in the waiting - // list, means that this node is not the - // head node. - ? waiting.find(src) == waiting.end() - // This tensor has no source node, - // it must be a input tensor. - : (/*inputs.insert(input),*/ true); - }); - // Moves head node to sorted. - if (is_head) { + for (auto const &op : ops) { + if (auto const &inputs = op->getInputs(); + flags.find(op.get()) == flags.end() && + std::all_of(inputs.begin(), inputs.end(), + [&flags](auto const &input) { + auto ptr = input->getSource().get(); + return !ptr || flags.find(ptr) != flags.end(); + })) { modified = true; - sorted.emplace_back(std::move(*it)); - it = waiting.erase(it); - } else { - ++it; + sorted.emplace_back(op); + flags.insert(op.get()); } } - // Waiting list never modifies during a pass, - // sorting fails. if (!modified) { return false; } } - // Done. this->ops = std::move(sorted); return this->sorted = true; } @@ -182,7 +167,10 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) { // note: behavior may not match running in non-naive mode, and it may // not reproduce the bug for (auto &tensor : tensors) { - tensor->dataMalloc(); + if (!tensor->isWeight() || + (tensor->isWeight() && !weightAllocated)) { + tensor->dataMalloc(); + } } return; } diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 4d64d433..f1ae8849 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { std::map opCnt; for (auto &op : graph->getOperators()) { - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); @@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const { std::map opCnt; for (auto &op : graph->getOperators()) { - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 0676646a..b92cb18f 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { auto &perfEngine = PerfEngine::getInstance(); for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); @@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = KernelAttrs{device, 
op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu index 83490959..e38910b9 100644 --- a/src/cuda/cuda_utility.cu +++ b/src/cuda/cuda_utility.cu @@ -1,4 +1,6 @@ +#include "core/data_type.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include __global__ void cudaPrintFloatImpl(float *x, int len) { @@ -18,4 +20,55 @@ void cudaPrintFloat(float *x, int len) { cudaDeviceSynchronize(); } +void cudaPrintTensor(const Tensor &tensor) { + cudaPrintFloat(tensor->getRawDataPtr(), tensor->size()); +} + +cudnnDataType_t cudnnDataTypeConvert(DataType dataType) { + if (dataType == DataType::Float32) { + return CUDNN_DATA_FLOAT; + } + if (dataType == DataType::Double) { + return CUDNN_DATA_DOUBLE; + } + if (dataType == DataType::Float16) { + return CUDNN_DATA_HALF; + } + if (dataType == DataType::Int8) { + return CUDNN_DATA_INT8; + } + if (dataType == DataType::Int32) { + return CUDNN_DATA_INT32; + } + if (dataType == DataType::UInt8) { + return CUDNN_DATA_UINT8; + } + if (dataType == DataType::BFloat16) { + return CUDNN_DATA_BFLOAT16; + } + if (dataType == DataType::Int64) { + return CUDNN_DATA_INT64; + } + if (dataType == DataType::Bool) { + return CUDNN_DATA_BOOLEAN; + } + IT_ASSERT(false, "Unsupported data type"); +} + +cudaDataType cublasDataTypeConvert(DataType dataType) { + switch (dataType.getIndex()) { + case 1: + return CUDA_R_32F; + // case 3: + // return CUDA_R_8I; + case 10: + return CUDA_R_16F; + case 11: + return CUDA_R_64F; + // case 16: + // return CUDA_R_16BF; + default: + IT_ASSERT(false, "MatMul Unsupported data type"); + } +} } // namespace infini diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc index 1d7b0c20..bc970760 100644 --- a/src/kernels/bang/activation.cc +++ b/src/kernels/bang/activation.cc @@ -11,6 +11,7 @@ class UnaryCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -50,6 +51,7 @@ class RoundCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -80,6 +82,7 @@ class PReluCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -119,6 +122,7 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -215,15 +219,12 @@ class SigmoidCnnl : public UnaryCnnl { float getCoef() const override { return 0.0; } }; -REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl, - "Relu_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl, - 
"PRelu_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl, - "Sigmoid_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl, - "Round_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl, - "Softmax_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl, + "Sigmoid_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl, + "Softmax_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/activation_backward.cc b/src/kernels/bang/activation_backward.cc index cc70afce..c2c3baa6 100644 --- a/src/kernels/bang/activation_backward.cc +++ b/src/kernels/bang/activation_backward.cc @@ -10,6 +10,7 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const yData = (op->getInputs(0)->getRawDataPtr()); @@ -81,11 +82,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl { float getCoef() const override { return 0.0; } }; -REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32, - ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32, - SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32, - TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, ReluBackwardCnnl, + "ReluBackward_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, SigmoidBackwardCnnl, + "SigmoidBackward_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, TanhBackwardCnnl, + "TanhBackward_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc index a1bc81c0..31aba547 100644 --- a/src/kernels/bang/batchnorm.cc +++ b/src/kernels/bang/batchnorm.cc @@ -7,6 +7,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const input = (op->getInputs(0)->getRawDataPtr()); @@ -101,7 +102,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32, - BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, BatchNormCnnl, + "BatchNorm_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/cast.cc b/src/kernels/bang/cast.cc index 267e4c2d..db281080 100644 --- a/src/kernels/bang/cast.cc +++ b/src/kernels/bang/cast.cc @@ -212,7 +212,6 @@ class CastCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl, - "Cast_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Cast, CastCnnl, "Cast_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/ceil.cc b/src/kernels/bang/ceil.cc index c3d0f3d0..ce77741d 100644 --- 
a/src/kernels/bang/ceil.cc +++ b/src/kernels/bang/ceil.cc @@ -7,6 +7,7 @@ class CeilCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,7 +36,6 @@ class CeilCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Ceil, DataType::Float32, CeilCnnl, - "Ceil_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Ceil, CeilCnnl, "Ceil_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/clip.cc b/src/kernels/bang/clip.cc index 12b71fdc..fdb682f0 100644 --- a/src/kernels/bang/clip.cc +++ b/src/kernels/bang/clip.cc @@ -7,6 +7,7 @@ class ClipCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -30,7 +31,6 @@ class ClipCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Clip, DataType::Float32, ClipCnnl, - "Clip_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Clip, ClipCnnl, "Clip_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/concat.cc b/src/kernels/bang/concat.cc index ab535879..dae092c5 100644 --- a/src/kernels/bang/concat.cc +++ b/src/kernels/bang/concat.cc @@ -7,6 +7,7 @@ class ConcatCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numInputs(); int axis = op->getDim(); @@ -50,6 +51,5 @@ class ConcatCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Concat, DataType::Float32, ConcatCnnl, - "Concat_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Concat, ConcatCnnl, "Concat_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/conv.cc b/src/kernels/bang/conv.cc index d9ff3df8..24d8a3fd 100644 --- a/src/kernels/bang/conv.cc +++ b/src/kernels/bang/conv.cc @@ -7,6 +7,7 @@ class ConvCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -151,6 +152,5 @@ class ConvCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Conv, DataType::Float32, ConvCnnl, - "Conv_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Conv, ConvCnnl, "Conv_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/conv_trans.cc b/src/kernels/bang/conv_trans.cc index a081e279..ce93fc9a 100644 --- a/src/kernels/bang/conv_trans.cc +++ b/src/kernels/bang/conv_trans.cc @@ -7,6 +7,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -76,6 +77,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32, 
- ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, ConvTransCnnl, + "ConvTrans_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/convbpfilter.cc b/src/kernels/bang/convbpfilter.cc index b360cedb..f3e9ec94 100644 --- a/src/kernels/bang/convbpfilter.cc +++ b/src/kernels/bang/convbpfilter.cc @@ -7,6 +7,7 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -154,6 +155,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32, - ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, + ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/det.cc b/src/kernels/bang/det.cc index 182baaa7..eeb197b6 100644 --- a/src/kernels/bang/det.cc +++ b/src/kernels/bang/det.cc @@ -7,6 +7,7 @@ class DetCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -42,6 +43,5 @@ class DetCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl, - "Det_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Det, DetCnnl, "Det_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/element_wise.cc b/src/kernels/bang/element_wise.cc index 9c1d95b4..e919e7d1 100644 --- a/src/kernels/bang/element_wise.cc +++ b/src/kernels/bang/element_wise.cc @@ -11,6 +11,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -74,6 +75,7 @@ class LogicOpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -127,6 +129,7 @@ class BitComputeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -179,6 +182,7 @@ class DivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -231,6 +235,7 @@ class MaximumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = 
(op->getInputs(0)->getRawDataPtr()); @@ -282,6 +287,7 @@ class MinimumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -333,6 +339,7 @@ class MSELossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -389,6 +396,7 @@ class PowerCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -442,6 +450,7 @@ class FloorDivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -494,6 +503,7 @@ class FloorModCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -546,6 +556,7 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -658,62 +669,48 @@ class BitNotCnnl : public BitComputeCnnl { // CNNL_BLEFT_SHIFT_OP_V2; } // }; -REGISTER_KERNEL(Device::BANG, OpType::Add, DataType::Float32, AddCnnl, - "Add_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl, - "Sub_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl, - "Mul_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Add, AddCnnl, "Add_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Sub, SubCnnl, "Sub_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Mul, MulCnnl, "Mul_cnnl_BANG"); -REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl, - "Div_cnnl_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl, - "Maximum_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl, - "Minimum_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl, - "MSELoss_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl, - "Power_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl, - "FloorDiv_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::FloorMod, DataType::Float32, FloorModCnnl, - "FloorMod_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32, - SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl, - 
"Equal_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32, - GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32, - GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl, - "LessThan_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32, - LessEqualCnnl, "LessEqual_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl, - "And_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Or, DataType::Float32, OrCnnl, - "Or_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl, - "Xor_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl, - "Not_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl, - "BitAnd_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl, - "BitOr_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl, - "BitXor_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl, - "BitNot_cnnl_BANG_Float32"); -// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32, +REGISTER_KERNEL(Device::BANG, OpType::Div, DivCnnl, "Div_cnnl"); +REGISTER_KERNEL(Device::BANG, OpType::Max, MaximumCnnl, "Maximum_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Min, MinimumCnnl, "Minimum_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::MSELoss, MSELossCnnl, + "MSELoss_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Pow, PowerCnnl, "Power_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, FloorDivCnnl, + "FloorDiv_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::FloorMod, FloorModCnnl, + "FloorMod_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, SquaredDifferenceCnnl, + "SquaredDifference_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Equal, EqualCnnl, "Equal_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Greater, GreaterThanCnnl, + "GreaterThan_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, GreaterEqualCnnl, + "GreaterEqual_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Less, LessThanCnnl, "LessThan_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, LessEqualCnnl, + "LessEqual_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::And, AndCnnl, "And_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Or, OrCnnl, "Or_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Xor, XorCnnl, "Xor_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Not, NotCnnl, "Not_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, BitAndCnnl, + "BitAnd_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, BitOrCnnl, "BitOr_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, BitXorCnnl, + "BitXor_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, BitNotCnnl, + "BitNot_cnnl_BANG"); +// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, // BitLeftShiftCnnl, -// "BitLeftShift_cnnl_BANG_Float32"); -// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift, DataType::Float32, +// "BitLeftShift_cnnl_BANG"); +// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift, // BitRightShiftCnnl, -// "BitRightShift_cnnl_BANG_Float32"); -// REGISTER_KERNEL(Device::BANG, 
OpType::Pow, DataType::Float32, +// "BitRightShift_cnnl_BANG"); +// REGISTER_KERNEL(Device::BANG, OpType::Pow, // ElementWiseBang, -// "Pow_Bang_Float32"); +// "Pow_Bang"); }; // namespace infini diff --git a/src/kernels/bang/erf.cc b/src/kernels/bang/erf.cc index 5f1c0985..dcf8eacd 100644 --- a/src/kernels/bang/erf.cc +++ b/src/kernels/bang/erf.cc @@ -7,6 +7,7 @@ class ErfCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -36,7 +37,6 @@ class ErfCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Erf, DataType::Float32, ErfCnnl, - "Erf_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Erf, ErfCnnl, "Erf_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/exp.cc b/src/kernels/bang/exp.cc index fa71be72..4b3d88ab 100644 --- a/src/kernels/bang/exp.cc +++ b/src/kernels/bang/exp.cc @@ -7,6 +7,7 @@ class ExpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -36,7 +37,6 @@ class ExpCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Exp, DataType::Float32, ExpCnnl, - "Exp_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Exp, ExpCnnl, "Exp_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/fill.cc b/src/kernels/bang/fill.cc index c3f75311..c2de64d5 100644 --- a/src/kernels/bang/fill.cc +++ b/src/kernels/bang/fill.cc @@ -7,6 +7,7 @@ class FillCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -29,7 +30,6 @@ class FillCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Fill, DataType::Float32, FillCnnl, - "Fill_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Fill, FillCnnl, "Fill_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/floor.cc b/src/kernels/bang/floor.cc index dd049d1d..83d8b505 100644 --- a/src/kernels/bang/floor.cc +++ b/src/kernels/bang/floor.cc @@ -7,6 +7,7 @@ class FloorCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,7 +36,7 @@ class FloorCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Floor, DataType::Float32, FloorCnnl, +REGISTER_KERNEL(Device::BANG, OpType::Floor, FloorCnnl, "Floor_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc index dc3ee636..97fa395c 100644 --- a/src/kernels/bang/gather.cc +++ b/src/kernels/bang/gather.cc @@ -7,6 +7,7 @@ class GatherCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); 
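
The pattern repeated across these BANG kernel files is: REGISTER_KERNEL no longer carries a DataType, and each kernel now guards or dispatches on op->getDType() inside compute() (the IT_ASSERT lines added above). A minimal self-contained sketch of that registry change follows; Dev, Op, DType, ToyOpInst, and registry are illustrative stand-ins, not the project's KernelRegistry API.

    #include <cassert>
    #include <functional>
    #include <map>
    #include <utility>

    enum class Dev { CUDA };
    enum class Op { Relu };
    enum class DType { F32, F16 };

    struct ToyOpInst { Op type; DType dtype; };

    // Old-style key would also have contained DType; the new key omits it.
    using Key = std::pair<Dev, Op>;
    using Kernel = std::function<void(const ToyOpInst &)>;

    std::map<Key, Kernel> &registry() {
        static std::map<Key, Kernel> r;
        return r;
    }

    void reluKernel(const ToyOpInst &op) {
        // The dtype check (or per-dtype dispatch) now lives in the kernel,
        // mirroring the IT_ASSERT guards added throughout this patch.
        assert(op.dtype == DType::F32 || op.dtype == DType::F16);
        // ... launch the matching implementation ...
    }

    int main() {
        registry()[{Dev::CUDA, Op::Relu}] = reluKernel; // one entry, any dtype
        registry().at({Dev::CUDA, Op::Relu})({Op::Relu, DType::F16});
    }
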
void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -49,7 +50,6 @@ class GatherCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl, - "Gather_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Gather, GatherCnnl, "Gather_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/hardtanh.cc b/src/kernels/bang/hardtanh.cc index 2cdb89fe..1f91084e 100644 --- a/src/kernels/bang/hardtanh.cc +++ b/src/kernels/bang/hardtanh.cc @@ -7,6 +7,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -30,7 +31,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, DataType::Float32, HardtanhCnnl, - "Hardtanh_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, HardtanhCnnl, + "Hardtanh_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/l2loss.cc b/src/kernels/bang/l2loss.cc index 7fb5d3a8..deb127be 100644 --- a/src/kernels/bang/l2loss.cc +++ b/src/kernels/bang/l2loss.cc @@ -7,6 +7,7 @@ class L2LossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -28,7 +29,6 @@ class L2LossCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::L2Loss, DataType::Float32, L2LossCnnl, - "L2Loss_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::L2Loss, L2LossCnnl, "L2Loss_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/layer_norm.cc b/src/kernels/bang/layer_norm.cc index 231177c5..acd36624 100644 --- a/src/kernels/bang/layer_norm.cc +++ b/src/kernels/bang/layer_norm.cc @@ -8,6 +8,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inputData = (op->getInputs(0)->getRawDataPtr()); @@ -58,7 +59,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32, - LayerNormCnnl, "LayerNorm_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, LayerNormCnnl, + "LayerNorm_BANG"); }; // namespace infini diff --git a/src/kernels/bang/log.cc b/src/kernels/bang/log.cc index 6237992e..c2a3e566 100644 --- a/src/kernels/bang/log.cc +++ b/src/kernels/bang/log.cc @@ -7,6 +7,7 @@ class LogCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -51,7 +52,6 @@ class LogCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Log, DataType::Float32, LogCnnl, - "Log_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Log, LogCnnl, "Log_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/lrn.cc b/src/kernels/bang/lrn.cc index 4183f0fd..14bca5fb 100644 --- 
a/src/kernels/bang/lrn.cc +++ b/src/kernels/bang/lrn.cc @@ -7,6 +7,7 @@ class LRNCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -56,7 +57,6 @@ class LRNCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl, - "LRN_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::LRN, LRNCnnl, "LRN_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/matmul.cc b/src/kernels/bang/matmul.cc index 368d6b1c..09780067 100644 --- a/src/kernels/bang/matmul.cc +++ b/src/kernels/bang/matmul.cc @@ -8,6 +8,7 @@ class MatmulCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); auto input_num = op->numInputs(); @@ -107,6 +108,5 @@ class MatmulCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl, - "Matmul_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::MatMul, MatmulCnnl, "Matmul_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/negtensor.cc b/src/kernels/bang/negtensor.cc index 02c5c37c..12377610 100644 --- a/src/kernels/bang/negtensor.cc +++ b/src/kernels/bang/negtensor.cc @@ -7,6 +7,7 @@ class NegTensorCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,7 +36,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Neg, DataType::Float32, NegTensorCnnl, - "Neg_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Neg, NegTensorCnnl, "Neg_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/pad.cc b/src/kernels/bang/pad.cc index c2503ca0..e8aafa1a 100644 --- a/src/kernels/bang/pad.cc +++ b/src/kernels/bang/pad.cc @@ -7,6 +7,7 @@ class PadCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -57,7 +58,6 @@ class PadCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Pad, DataType::Float32, PadCnnl, - "Pad_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Pad, PadCnnl, "Pad_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/pooling.cc b/src/kernels/bang/pooling.cc index f3cf04bc..90a0637f 100644 --- a/src/kernels/bang/pooling.cc +++ b/src/kernels/bang/pooling.cc @@ -8,6 +8,7 @@ class PoolingCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); @@ -68,8 +69,8 @@ class avgPoolCnnl : public PoolingCnnl { } }; -REGISTER_KERNEL(Device::BANG, OpType::MaxPool, 
DataType::Float32, maxPoolCnnl, - "MaxPool_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32, - avgPoolCnnl, "AvgPool_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::MaxPool, maxPoolCnnl, + "MaxPool_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AveragePool, avgPoolCnnl, + "AvgPool_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/reciprocal.cc b/src/kernels/bang/reciprocal.cc index 6ac3f334..7b61c2ca 100644 --- a/src/kernels/bang/reciprocal.cc +++ b/src/kernels/bang/reciprocal.cc @@ -7,6 +7,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,7 +36,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32, - ReciprocalCnnl, "Reciprocal_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, ReciprocalCnnl, + "Reciprocal_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/reduce.cc b/src/kernels/bang/reduce.cc index 88d1e645..810aca72 100644 --- a/src/kernels/bang/reduce.cc +++ b/src/kernels/bang/reduce.cc @@ -9,6 +9,7 @@ class ReduceCnnlBase : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -73,9 +74,9 @@ class ReduceSumCnnl : public ReduceCnnlBase { cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; } }; -REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32, - ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32, - ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, ReduceMeanCnnl, + "ReduceMean_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, ReduceSumCnnl, + "ReduceSum_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/reshape.cc b/src/kernels/bang/reshape.cc index f5628a7b..cd876bf1 100644 --- a/src/kernels/bang/reshape.cc +++ b/src/kernels/bang/reshape.cc @@ -13,9 +13,9 @@ class CopyBang : public BangKernelWithoutConfig { auto dim = op->getInputs(0)->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dim.size(), - dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8, + dim.size() * op->getDType().getSize(), dim.data())); cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData); if (stat != CNNL_STATUS_SUCCESS) @@ -25,13 +25,8 @@ class CopyBang : public BangKernelWithoutConfig { } }; // reshape/flatten/identity all act as copying from input to output. 
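
Because reshape/flatten/identity only move bytes, the CopyBang change above describes the buffer as INT8 scaled by the element width, so a single registration can serve every data type. A short host-side sketch of that idea, with memcpy standing in for cnnlCopy and copyAsBytes as a hypothetical helper name:

    #include <cstddef>
    #include <cstring>

    // Copy a tensor as raw bytes: element count times element size,
    // independent of whether the elements are fp32, fp16, int8, ...
    void copyAsBytes(const void *src, void *dst, std::size_t elementCount,
                     std::size_t elementSize) {
        std::memcpy(dst, src, elementCount * elementSize);
    }

    int main() {
        float src[4] = {1, 2, 3, 4}, dst[4] = {};
        copyAsBytes(src, dst, 4, sizeof(float));
        return dst[3] == 4.0f ? 0 : 1;
    }
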
-REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang, - "Reshape_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang, - "Reshape_BANG_Int64"); -REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang, - "Flatten_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang, - "Identity_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG"); } // namespace infini diff --git a/src/kernels/bang/rsqrt.cc b/src/kernels/bang/rsqrt.cc index 0da3c74d..66e63e0a 100644 --- a/src/kernels/bang/rsqrt.cc +++ b/src/kernels/bang/rsqrt.cc @@ -7,6 +7,7 @@ class RsqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -36,7 +37,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, DataType::Float32, RsqrtCnnl, - "Rsqrt_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, RsqrtCnnl, "Rsqrt_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/split.cc b/src/kernels/bang/split.cc index bf3f8123..397b5063 100644 --- a/src/kernels/bang/split.cc +++ b/src/kernels/bang/split.cc @@ -7,6 +7,7 @@ class SplitCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numOutputs(); int axis = op->getDim(); @@ -49,6 +50,5 @@ class SplitCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Split, DataType::Float32, SplitCnnl, - "Split_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Split, SplitCnnl, "Split_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/sqrt.cc b/src/kernels/bang/sqrt.cc index 52fea02a..a1ed85c9 100644 --- a/src/kernels/bang/sqrt.cc +++ b/src/kernels/bang/sqrt.cc @@ -7,6 +7,7 @@ class SqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -36,7 +37,6 @@ class SqrtCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Sqrt, DataType::Float32, SqrtCnnl, - "Sqrt_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Sqrt, SqrtCnnl, "Sqrt_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/transpose.cc b/src/kernels/bang/transpose.cc index ff2783b5..7dedd21d 100644 --- a/src/kernels/bang/transpose.cc +++ b/src/kernels/bang/transpose.cc @@ -7,6 +7,7 @@ class TransposeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -52,6 +53,7 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const 
override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -101,9 +103,9 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32, - TransposeCnnl, "Transpose_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Transpose, TransposeCnnl, + "Transpose_cnnl_BANG"); -REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DataType::Float32, - DepthToSpaceCnnl, "DepthToSpace_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DepthToSpaceCnnl, + "DepthToSpace_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/trigon.cc b/src/kernels/bang/trigon.cc index b4842b95..989858c4 100644 --- a/src/kernels/bang/trigon.cc +++ b/src/kernels/bang/trigon.cc @@ -9,6 +9,7 @@ class TrigonCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -150,29 +151,17 @@ class ATanHCnnl : public TrigonCnnl { } }; -REGISTER_KERNEL(Device::BANG, OpType::Sin, DataType::Float32, SinCnnl, - "Sin_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl, - "Cos_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl, - "Tan_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl, - "ASin_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl, - "ACos_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl, - "ATan_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl, - "SinH_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl, - "CosH_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl, - "TanH_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl, - "ASinH_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl, - "ACosH_cnnl_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl, - "ATanH_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Sin, SinCnnl, "Sin_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Cos, CosCnnl, "Cos_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Tan, TanCnnl, "Tan_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Asin, ASinCnnl, "ASin_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Acos, ACosCnnl, "ACos_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Atan, ATanCnnl, "ATan_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Sinh, SinHCnnl, "SinH_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Cosh, CosHCnnl, "CosH_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Tanh, TanHCnnl, "TanH_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Asinh, ASinHCnnl, "ASinH_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Acosh, ACosHCnnl, "ACosH_cnnl_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::Atanh, ATanHCnnl, "ATanH_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/bang/where.cc b/src/kernels/bang/where.cc index 725b63e0..8786f3fd 100644 --- a/src/kernels/bang/where.cc +++ 
b/src/kernels/bang/where.cc @@ -7,6 +7,7 @@ class WhereCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -67,7 +68,6 @@ class WhereCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl, - "Where_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Where, WhereCnnl, "Where_cnnl_BANG"); }; // namespace infini diff --git a/src/kernels/cpu/concat.cc b/src/kernels/cpu/concat.cc index 5dd73866..156a16af 100644 --- a/src/kernels/cpu/concat.cc +++ b/src/kernels/cpu/concat.cc @@ -3,9 +3,9 @@ namespace infini { -template class NaiveConcat : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveConcat : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); auto inputs = op->getInputs(), outputs = op->getOutputs(); auto dim = op->getDim(); @@ -41,11 +41,25 @@ template class NaiveConcat : public CpuKernelWithoutConfig { } } } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::UInt32, - NaiveConcat, "ConcatNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::Float32, - NaiveConcat, "ConcatNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Concat, NaiveConcat, "ConcatNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/conv.cc b/src/kernels/cpu/conv.cc index b0ffa724..9300c72a 100644 --- a/src/kernels/cpu/conv.cc +++ b/src/kernels/cpu/conv.cc @@ -3,9 +3,9 @@ namespace infini { -template class NaiveConv : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveConv : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *iptr = op->getInputs(0)->getRawDataPtr(); T *wptr = op->getInputs(1)->getRawDataPtr(); @@ -50,11 +50,25 @@ template class NaiveConv : public CpuKernelWithoutConfig { } } } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32, - NaiveConv, "ConvNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv, - "ConvNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Conv, NaiveConv, "ConvNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/element_wise.cc b/src/kernels/cpu/element_wise.cc index ff03350c..98e974d3 100644 --- a/src/kernels/cpu/element_wise.cc +++ b/src/kernels/cpu/element_wise.cc @@ -3,10 +3,45 @@ #include "utils/operator_utils.h" namespace infini { -template class NativeElementWise : public 
CpuKernelWithoutConfig { - virtual T doCompute(T val0, T val1) const = 0; - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NativeElementWise : public CpuKernelWithoutConfig { + template static T addCompute(T val0, T val1) { + return val0 + val1; + } + + template static T subCompute(T val0, T val1) { + return val0 - val1; + } + + template static T mulCompute(T val0, T val1) { + return val0 * val1; + } + + template static T divCompute(T val0, T val1) { + return (T)(val0 / val1); + } + + template static T equalCompute(T val0, T val1) { + return (T)(val0 == val1); + } + + template static T greaterOrEqualCompute(T val0, T val1) { + return (T)(val0 >= val1); + } + + template static T greaterCompute(T val0, T val1) { + return (T)(val0 > val1); + } + + template static T lessOrEqualCompute(T val0, T val1) { + return (T)(val0 <= val1); + } + + template static T lessCompute(T val0, T val1) { + return (T)(val0 < val1); + } + + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *inptr0 = op->getInputs(0)->getRawDataPtr(); T *inptr1 = op->getInputs(1)->getRawDataPtr(); @@ -35,77 +70,77 @@ template class NativeElementWise : public CpuKernelWithoutConfig { Shape strideB = getStride(b); auto n = op->getOutput()->size(); + T (*_doCompute)(T val0, T val1); + switch (op->getOpType().underlying()) { + case OpType::Add: + _doCompute = addCompute; + break; + case OpType::Sub: + _doCompute = subCompute; + break; + case OpType::Mul: + _doCompute = mulCompute; + break; + case OpType::Div: + _doCompute = divCompute; + break; + case OpType::Equal: + _doCompute = equalCompute; + break; + case OpType::GreaterOrEqual: + _doCompute = greaterOrEqualCompute; + break; + case OpType::Greater: + _doCompute = greaterCompute; + break; + case OpType::LessOrEqual: + _doCompute = lessOrEqualCompute; + break; + case OpType::Less: + _doCompute = lessCompute; + break; + default: + IT_TODO_HALT(); + } + for (size_t i = 0; i < n; ++i) { auto shapeIndexC = locate_index(i, shapeC); auto indexA = delocate_index(shapeIndexC, a, strideA); auto indexB = delocate_index(shapeIndexC, b, strideB); - outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]); + outptr[i] = _doCompute(inptr0[indexA], inptr1[indexB]); + } + } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); } } }; -template class NaiveAdd : public NativeElementWise { - T doCompute(T val0, T val1) const override { return val0 + val1; } -}; -template class NaiveSub : public NativeElementWise { - T doCompute(T val0, T val1) const override { return val0 - val1; } -}; -template class NaiveMul : public NativeElementWise { - T doCompute(T val0, T val1) const override { return val0 * val1; } -}; -template class NaiveDiv : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 / val1); } -}; -template class NaiveEqual : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 == val1); } -}; -template class NaiveGreaterEqual : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); } -}; -template class NaiveGreaterThan : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 > val1); 
} -}; -template class NaiveLessEqual : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); } -}; -template class NaiveLessThan : public NativeElementWise { - T doCompute(T val0, T val1) const override { return (T)(val0 < val1); } -}; - -REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd, - "addNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd, - "addNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub, - "subNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub, - "subNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul, - "mulNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul, - "mulNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv, - "divNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv, - "divNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32, - NaiveEqual, "equalNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32, - NaiveEqual, "equalNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32, - NaiveGreaterEqual, "greaterEqualNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32, - NaiveGreaterEqual, "greaterEqualNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32, - NaiveGreaterThan, "greaterThanNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32, - NaiveGreaterThan, "greaterThanNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32, - NaiveLessEqual, "lessEqualNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32, - NaiveLessEqual, "lessEqualNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32, - NaiveLessThan, "lessEqualNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32, - NaiveLessThan, "lessEqualNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Add, NativeElementWise, "addNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Sub, NativeElementWise, "subNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Mul, NativeElementWise, "mulNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Div, NativeElementWise, "divNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Equal, NativeElementWise, + "equalNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, NativeElementWise, + "greaterEqualNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Greater, NativeElementWise, + "greaterThanNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, NativeElementWise, + "lessEqualNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Less, NativeElementWise, + "lessEqualNaive_CPU"); }; // namespace infini diff --git a/src/kernels/cpu/matmul.cc b/src/kernels/cpu/matmul.cc index 248cb60b..6a863402 100644 --- a/src/kernels/cpu/matmul.cc +++ b/src/kernels/cpu/matmul.cc @@ -3,9 +3,9 @@ namespace infini { -template class NaiveMatmul : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveMatmul : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); IT_ASSERT(op->getInputs().size() 
== 2, "Bias is not supported yet."); T *A = op->getInputs(0)->getRawDataPtr(); @@ -23,11 +23,25 @@ template class NaiveMatmul : public CpuKernelWithoutConfig { } } } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32, - NaiveMatmul, "MatmulNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32, - NaiveMatmul, "MatmulNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::MatMul, NaiveMatmul, "MatmulNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/membound.cc b/src/kernels/cpu/membound.cc index b6b6c7ee..a2fd6232 100644 --- a/src/kernels/cpu/membound.cc +++ b/src/kernels/cpu/membound.cc @@ -80,8 +80,8 @@ class MemboundInterpreter : public Kernel { } }; -REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32, - MemboundInterpreter, "MemboundInterpreter_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::MemBound, MemboundInterpreter, + "MemboundInterpreter_CPU"); } // namespace infini diff --git a/src/kernels/cpu/pooling.cc b/src/kernels/cpu/pooling.cc index 1242e14f..a076011a 100644 --- a/src/kernels/cpu/pooling.cc +++ b/src/kernels/cpu/pooling.cc @@ -2,42 +2,10 @@ #include "core/kernel.h" namespace infini { -template class NativePooling : public CpuKernelWithoutConfig { - virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih, - int iw, T *inptr) const = 0; - void compute(const Operator &_op, - const RuntimeObj *context) const override { - auto op = as(_op); - T *inptr = op->getInputs(0)->getRawDataPtr(); - T *outptr = op->getOutput()->getRawDataPtr(); - const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS(); - const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); - if (dh != 1 || dw != 1) - IT_TODO_HALT(); // To support dailated pooling - auto outDim = op->getOutput()->getDims(); - int oh = outDim[2], ow = outDim[3]; - for (auto i = 0; i < n; i++) { - for (auto j = 0; j < c; j++) { - auto inoffset = i * (c * ih * iw) + j * ih * iw; - for (auto h = 0; h < oh; h++) { - for (auto w = 0; w < ow; w++) { - // TODO: verify ceil mode - T val = - getPoolingValue(kh, kw, h * sh - ph, w * sw - pw, - ih, iw, inptr + inoffset); - auto outoffset = - w + h * ow + j * (oh * ow) + i * (c * oh * ow); - outptr[outoffset] = val; - } - } - } - } - } -}; - -template class NaiveMaxPool : public NativePooling { - T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw, - T *inptr) const override { +class NativePooling : public CpuKernelWithoutConfig { + template + static T getMaxPoolingValue(int kh, int kw, int posh, int posw, int ih, + int iw, T *inptr) { T maxval = 0; for (auto k = 0; k < kh; k++) { for (auto l = 0; l < kw; l++) { @@ -53,11 +21,10 @@ template class NaiveMaxPool : public NativePooling { } return maxval; } -}; -template class NaiveAvgPool : public NativePooling { - T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw, - T *inptr) const override { + template + static T getAvgPoolingValue(int kh, int kw, int posh, int posw, int ih, + int iw, T *inptr) { T sum = 0; for (auto k = 0; k < kh; k++) { for (auto l = 0; l < kw; l++) { @@ -71,12 +38,70 @@ template class NaiveAvgPool : public NativePooling { } return T(sum / (kh * kw)); } + + template + void 
doCompute(const Operator &_op, const RuntimeObj *context) const { + auto op = as(_op); + T *inptr = op->getInputs(0)->getRawDataPtr(); + T *outptr = op->getOutput()->getRawDataPtr(); + + const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS(); + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + if (dh != 1 || dw != 1) + IT_TODO_HALT(); // To support dailated pooling + auto outDim = op->getOutput()->getDims(); + int oh = outDim[2], ow = outDim[3]; + + T(*_doCompute) + (int kh, int kw, int posh, int posw, int ih, int iw, T *inptr); + switch (op->getOpType().underlying()) { + case OpType::MaxPool: + _doCompute = getMaxPoolingValue; + break; + case OpType::AveragePool: + _doCompute = getAvgPoolingValue; + break; + default: + IT_TODO_HALT(); + } + + for (auto i = 0; i < n; i++) { + for (auto j = 0; j < c; j++) { + auto inoffset = i * (c * ih * iw) + j * ih * iw; + for (auto h = 0; h < oh; h++) { + for (auto w = 0; w < ow; w++) { + // TODO: verify ceil mode + T val = _doCompute(kh, kw, h * sh - ph, w * sw - pw, ih, + iw, inptr + inoffset); + auto outoffset = + w + h * ow + j * (oh * ow) + i * (c * oh * ow); + outptr[outoffset] = val; + } + } + } + } + } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32, - NaiveMaxPool, "maxPoolNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32, - NaiveMaxPool, "maxPoolNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32, - NaiveAvgPool, "AvgPoolNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::MaxPool, NativePooling, + "maxPoolNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::AveragePool, NativePooling, + "avgPoolNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/split.cc b/src/kernels/cpu/split.cc index 3ef0cea3..3da5ade8 100644 --- a/src/kernels/cpu/split.cc +++ b/src/kernels/cpu/split.cc @@ -3,9 +3,9 @@ namespace infini { -template class NaiveSplit : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveSplit : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); auto inputs = op->getInputs(), outputs = op->getOutputs(); auto dim = op->getDim(); @@ -40,11 +40,24 @@ template class NaiveSplit : public CpuKernelWithoutConfig { } } } + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::UInt32, - NaiveSplit, "SplitNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::Float32, - NaiveSplit, "SplitNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Split, NaiveSplit, "SplitNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/transpose.cc b/src/kernels/cpu/transpose.cc index 997c427e..46292d45 100644 --- a/src/kernels/cpu/transpose.cc +++ b/src/kernels/cpu/transpose.cc @@ -14,9 +14,9 @@ 
inline Shape idx2Pos(const Shape &shape, size_t idx) { return pos; } -template class NaiveTranspose : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveTranspose : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); auto inputs = op->getInputs(), outputs = op->getOutputs(); const auto &inDim = inputs[0]->getDims(); @@ -35,11 +35,26 @@ template class NaiveTranspose : public CpuKernelWithoutConfig { outPtr[outIdx] = inPtr[inIdx]; } } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::UInt32, - NaiveTranspose, "TransposeNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::Float32, - NaiveTranspose, "TransposeNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Transpose, NaiveTranspose, + "TransposeNaive_CPU"); } // namespace infini diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 3ea61b41..024d720a 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -4,25 +4,170 @@ #include "operators/softmax.h" namespace infini { -template class NativeUnary : public CpuKernelWithoutConfig { - virtual T doCompute(T val) const = 0; - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NativeUnary : public CpuKernelWithoutConfig { + template static T reluCompute(T val) { + return std::max(T(0), val); + } + + template static T sigmoidCompute(T val) { + return 1 / (1 + pow(E_CONSTANT, -val)); + } + + template static T hardSigmoidCompute(T val) { + return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5))); + } + + template static T hardSwishCompute(T val) { + return val * + std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5))); + } + + template static T tanhCompute(T val) { + return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) / + (pow(E_CONSTANT, val) + pow(E_CONSTANT, -val)); + } + + template static T absCompute(T val) { + return val < 0 ? 
-val : val; + } + + template static T sqrtCompute(T val) { return std::sqrt(val); } + + template static T cosCompute(T val) { return std::cos(val); } + + template static T sinCompute(T val) { return std::sin(val); } + + template static T tanCompute(T val) { return std::tan(val); } + + template static T sinhCompute(T val) { return std::sinh(val); } + + template static T coshCompute(T val) { return std::cosh(val); } + + template static T geluCompute(T val) { + return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); + } + + template static T erfCompute(T val) { return std::erf(val); } + + template static T aCosCompute(T val) { return std::acos(val); } + + template static T aCoshCompute(T val) { + return std::acosh(val); + } + + template static T aSinCompute(T val) { return std::asin(val); } + + template static T aSinhCompute(T val) { + return std::asinh(val); + } + template static T aTanCompute(T val) { return std::atan(val); } + + template static T aTanhCompute(T val) { + return std::atanh(val); + } + template static T negCompute(T val) { return -val; } + + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *inptr = op->getInputs(0)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); auto outDim = op->getOutput()->getDims(); auto n = op->getOutput()->size(); + + T (*_doCompute)(T val); + switch (op->getOpType().underlying()) { + case OpType::Relu: + _doCompute = reluCompute; + break; + case OpType::Gelu: + _doCompute = geluCompute; + break; + case OpType::Sigmoid: + _doCompute = sigmoidCompute; + break; + case OpType::HardSigmoid: + _doCompute = hardSigmoidCompute; + break; + case OpType::HardSwish: + _doCompute = hardSwishCompute; + break; + case OpType::Tanh: + _doCompute = tanhCompute; + break; + case OpType::Abs: + _doCompute = absCompute; + break; + case OpType::Sqrt: + _doCompute = sqrtCompute; + break; + case OpType::Erf: + _doCompute = erfCompute; + break; + case OpType::Neg: + _doCompute = negCompute; + break; + case OpType::Cos: + _doCompute = cosCompute; + break; + case OpType::Sin: + _doCompute = sinCompute; + break; + case OpType::Tan: + _doCompute = tanCompute; + break; + case OpType::Sinh: + _doCompute = sinhCompute; + break; + case OpType::Cosh: + _doCompute = coshCompute; + break; + case OpType::Acos: + _doCompute = aCosCompute; + break; + case OpType::Asin: + _doCompute = aSinCompute; + break; + case OpType::Asinh: + _doCompute = aSinhCompute; + break; + case OpType::Atan: + _doCompute = aTanCompute; + break; + case OpType::Atanh: + _doCompute = aTanhCompute; + break; + default: + IT_TODO_HALT(); + } + for (size_t offset = 0; offset < n; offset++) { - outptr[offset] = doCompute(inptr[offset]); + outptr[offset] = _doCompute(inptr[offset]); + } + } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); } } }; -template class NaiveSoftmax : public CpuKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *context) const override { +class NaiveSoftmax : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *inptr = op->getInputs(0)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); @@ -37,98 +182,28 @@ template class NaiveSoftmax 
: public CpuKernelWithoutConfig { outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum; } } -}; -template class NaiveRelu : public NativeUnary { - T doCompute(T val) const override { return std::max(T(0), val); } -}; -template class NaiveSigmoid : public NativeUnary { - T doCompute(T val) const override { - return 1 / (1 + pow(E_CONSTANT, -val)); - } -}; -template class NaiveHardSigmoid : public NativeUnary { - T doCompute(T val) const override { - return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5))); - } -}; -template class NaiveHardSwish : public NativeUnary { - T doCompute(T val) const override { - return val * - std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5))); - } -}; -template class NaiveTanh : public NativeUnary { - T doCompute(T val) const override { - return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) / - (pow(E_CONSTANT, val) + pow(E_CONSTANT, -val)); - } -}; -template class NaiveAbs : public NativeUnary { - T doCompute(T val) const override { return val < 0 ? -val : val; } -}; - -template class NaiveSqrt : public NativeUnary { - T doCompute(T val) const override { return std::sqrt(val); } -}; - -template class NaiveCos : public NativeUnary { - T doCompute(T val) const override { return std::cos(val); } -}; - -template class NaiveSin : public NativeUnary { - T doCompute(T val) const override { return std::sin(val); } -}; - -template class NaiveTan : public NativeUnary { - T doCompute(T val) const override { return std::tan(val); } -}; - -template class NaiveSinh : public NativeUnary { - T doCompute(T val) const override { return std::sinh(val); } -}; - -template class NaiveCosh : public NativeUnary { - T doCompute(T val) const override { return std::cosh(val); } -}; - -template class NaiveGelu : public NativeUnary { - T doCompute(T val) const override { - return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); - } -}; - -template class NaiveErf : public NativeUnary { - T doCompute(T val) const override { return std::erf(val); } -}; - -template class NaiveACos : public NativeUnary { - T doCompute(T val) const override { return std::acos(val); } -}; - -template class NaiveACosh : public NativeUnary { - T doCompute(T val) const override { return std::acosh(val); } -}; - -template class NaiveASin : public NativeUnary { - T doCompute(T val) const override { return std::asin(val); } -}; - -template class NaiveASinh : public NativeUnary { - T doCompute(T val) const override { return std::asinh(val); } -}; - -template class NaiveATanh : public NativeUnary { - T doCompute(T val) const override { return std::atanh(val); } -}; - -template class NaiveNeg : public NativeUnary { - T doCompute(T val) const override { return -val; } -}; - -template class Clip : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } +}; + +class Clip : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *inptr = op->getInputs(0)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); @@ -143,11 +218,28 @@ template class Clip : public CpuKernelWithoutConfig { : val; } } -}; -template class Log : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const 
override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } +}; + +class Log : public CpuKernelWithoutConfig { + template + void doCompute(const Operator &_op, const RuntimeObj *context) const { auto op = as(_op); T *inptr = op->getInputs(0)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); @@ -176,70 +268,50 @@ template class Log : public CpuKernelWithoutConfig { } } } + + void compute(const Operator &_op, + const RuntimeObj *context) const override { +#define CASE(N) \ + case N: \ + doCompute::t>(_op, context) + + int dataTypeIdx = _op->getDType().getIndex(); + switch (dataTypeIdx) { + CASE(1); // DataType::Float32 + break; + CASE(12); // DataType::UInt32 + break; + default: + IT_TODO_HALT(); + } + } }; -template class NaiveATan : public NativeUnary { - T doCompute(T val) const override { return std::atan(val); } -}; +REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary, + "hardSigmoidNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::HardSwish, NativeUnary, + "hardSwishNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Tanh, NativeUnary, "tanhNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Abs, NativeUnary, "absNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Sqrt, NativeUnary, "sqrtNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Erf, NativeUnary, "erfNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Neg, NativeUnary, "negNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Cos, NativeUnary, "Cos_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Sin, NativeUnary, "Sin_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Tan, NativeUnary, "Tan_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Sinh, NativeUnary, "Sinh_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Cosh, NativeUnary, "Cosh_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Acos, NativeUnary, "ACos_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Acosh, NativeUnary, "ACosh_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Asin, NativeUnary, "ASin_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Asinh, NativeUnary, "ASinh_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Atan, NativeUnary, "Atan_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Atanh, NativeUnary, "ATanh_CPU"); -REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32, - NaiveRelu, "reluNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu, - "reluNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu, - "geluNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu, - "geluNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32, - NaiveSigmoid, "sigmoidNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32, - NaiveSigmoid, "sigmoidNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32, - NaiveHardSigmoid, "hardSigmoidNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32, - NaiveHardSwish, "hardSwishNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32, - NaiveTanh, 
"tanhNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh, - "tanhNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs, - "absNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs, - "absNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt, - "sqrtNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf, - "erfNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg, - "negNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32, - NaiveSoftmax, "softmaxNaive_CPU_uint32"); -REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, - NaiveSoftmax, "softmaxNaive_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip, - "Clip_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan, - "Atan_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log, - "Log_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos, - "Cos_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin, - "Sin_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan, - "Tan_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh, - "Sinh_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh, - "Cosh_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos, - "ACos_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32, - NaiveACosh, "ACosh_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin, - "ASin_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32, - NaiveASinh, "ASinh_CPU_float32"); -REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32, - NaiveATanh, "ATanh_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Softmax, NaiveSoftmax, "softmaxNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Clip, Clip, "Clip_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Log, Log, "Log_CPU"); }; // namespace infini diff --git a/src/kernels/cuda/G2BMM.cc b/src/kernels/cuda/G2BMM.cc index cb69f76a..133e4c4d 100644 --- a/src/kernels/cuda/G2BMM.cc +++ b/src/kernels/cuda/G2BMM.cc @@ -48,13 +48,13 @@ class G2BMMCudnn : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); bool success = g2bmmKernel(op, context); IT_ASSERT(success); } }; -REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn, - "G2BMM_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, G2BMMCudnn, "G2BMM_cuDNN_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/GBMM.cc b/src/kernels/cuda/GBMM.cc index 06002850..392101ab 100644 --- a/src/kernels/cuda/GBMM.cc +++ b/src/kernels/cuda/GBMM.cc @@ -49,13 +49,13 @@ class GBMMCudnn : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); bool success = gbmmKernel(op, context); IT_ASSERT(success); } }; -REGISTER_KERNEL(Device::CUDA, OpType::GBMM, 
DataType::Float32, GBMMCudnn, - "GBMM_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::GBMM, GBMMCudnn, "GBMM_cuDNN_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/all_gather.cc b/src/kernels/cuda/all_gather.cc index 187aea5c..261f9070 100644 --- a/src/kernels/cuda/all_gather.cc +++ b/src/kernels/cuda/all_gather.cc @@ -39,8 +39,8 @@ class AllGatherNCCL : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32, - AllGatherNCCL, "AllGather_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllGather, AllGatherNCCL, + "AllGather_NCCL_CUDA"); } // namespace infini #endif diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc index ef60b991..8b64d2ab 100644 --- a/src/kernels/cuda/all_reduce.cc +++ b/src/kernels/cuda/all_reduce.cc @@ -13,15 +13,24 @@ class AllReduceNCCL : public CudaKernelWithoutConfig { auto context = dynamic_cast(_context); void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); - IT_ASSERT(op->getDType() == DataType::Float32); + ncclDataType_t ncclType = ncclFloat; + if (op->getDType() == DataType::Float16) { + ncclType = ncclFloat16; + } else if (op->getDType() == DataType::Int8) { + ncclType = ncclInt8; + } else if (op->getDType() == DataType::Float32) { + ncclType = ncclFloat; + } else { + IT_TODO_HALT(); + } size_t count = op->getInputs(0)->size(); ncclComm_t comm = dynamic_cast(context->getCommunicator()) .getNcclComm(); // TODO: Using default stream 0 for now. - checkNcclError(ncclAllReduce(input, output, count, ncclFloat, - getRedOp(), comm, 0)); + checkNcclError( + ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0)); } virtual ncclRedOp_t getRedOp() const = 0; @@ -43,16 +52,16 @@ class AllReduceAvgNCCL : public AllReduceNCCL { ncclRedOp_t getRedOp() const override { return ncclAvg; } }; -REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32, - AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32, - AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32, - AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32, - AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32, - AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, AllReduceSumNCCL, + "AllReduce_Sum_NCCL_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, AllReduceProdNCCL, + "AllReduce_Prod_NCCL_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, AllReduceMinNCCL, + "AllReduce_Min_NCCL_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, AllReduceMaxNCCL, + "AllReduce_Max_NCCL_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, AllReduceAvgNCCL, + "AllReduce_Avg_NCCL_CUDA"); } // namespace infini #endif diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc index 0d21603a..52356d8d 100644 --- a/src/kernels/cuda/attention_kvcache.cc +++ b/src/kernels/cuda/attention_kvcache.cc @@ -40,6 +40,7 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute, public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { + IT_ASSERT(_op->getDType() == DataType::Float32); 
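Because REGISTER_KERNEL no longer carries a data type, each kernel now either handles the tensor's dtype itself or rejects it explicitly. The all_reduce.cc hunk above shows the accept path (map the framework dtype to an ncclDataType_t, halting on anything unsupported), while the IT_ASSERT here shows the reject path for kernels that remain fp32-only. A hypothetical helper capturing the same NCCL mapping might look like the sketch below; the patch itself keeps the if/else chain inline and the Dtype enum here is only a stand-in for the framework's DataType.

#include <nccl.h>
#include <stdexcept>

// Toy dtype enum standing in for the framework's DataType objects.
enum class Dtype { Float32, Float16, Int8, Other };

// Map only the dtypes the kernel actually supports; reject everything else
// loudly instead of silently defaulting to float.
inline ncclDataType_t toNcclType(Dtype dt) {
    switch (dt) {
    case Dtype::Float32: return ncclFloat;
    case Dtype::Float16: return ncclFloat16;
    case Dtype::Int8:    return ncclInt8;
    default:
        throw std::runtime_error("dtype not supported by this NCCL kernel");
    }
}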
do_compute(_op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2], _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5], @@ -47,6 +48,6 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute, } }; -REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32, - AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, AttentionKVCacheCuda, + "AttentionKVCache_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc index 1df7313f..b083ad9c 100644 --- a/src/kernels/cuda/batch_norm.cc +++ b/src/kernels/cuda/batch_norm.cc @@ -10,6 +10,7 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { auto op = as(_op); auto context = dynamic_cast(_context); cudnnStatus_t stat; + IT_ASSERT(op->getDType() == DataType::Float32); void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); void *const meanData = (op->getInputs(1)->getRawDataPtr()); @@ -59,6 +60,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32, - BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, BatchNormCudnn, + "BatchNorm_cuDNN_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/broadcast.cc b/src/kernels/cuda/broadcast.cc index 79190491..6fb35914 100644 --- a/src/kernels/cuda/broadcast.cc +++ b/src/kernels/cuda/broadcast.cc @@ -25,8 +25,8 @@ class BroadcastNCCL : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32, - BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, BroadcastNCCL, + "Broadcast_NCCL_CUDA"); } // namespace infini #endif diff --git a/src/kernels/cuda/clip.cc b/src/kernels/cuda/clip.cc index b4865504..55184eb9 100644 --- a/src/kernels/cuda/clip.cc +++ b/src/kernels/cuda/clip.cc @@ -9,7 +9,7 @@ class ClipCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - + IT_ASSERT(op->getDType() == DataType::Float32); void *const inputData = (op->getInputs(0)->getRawDataPtr()); void *const outputData = (op->getOutput()->getRawDataPtr()); auto min = op->getMin(); @@ -21,7 +21,6 @@ class ClipCuda : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Clip, DataType::Float32, ClipCuda, - "Clip_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Clip, ClipCuda, "Clip_CUDA"); }; // namespace infini diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc index c020ed33..de7d8d09 100644 --- a/src/kernels/cuda/conv.cc +++ b/src/kernels/cuda/conv.cc @@ -1,10 +1,12 @@ #include "operators/conv.h" #include "core/kernel.h" #include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" #include #include #include #include + namespace infini { struct ConvCuDnnPerfRecordObj : public PerfRecordObj { @@ -56,8 +58,11 @@ class convCudnn : public Kernel { const ConvCuDnnPerfRecord &record) const { void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const knData = (op->getInputs(1)->getRawDataPtr()); - if (op->getInputs().size() > 2) // Bias is not supported yet + // Bias is not supported yet + if (op->getInputs().size() > 2) { IT_TODO_HALT(); + } + auto cudnnDataType = cudnnDataTypeConvert(op->getDType()); // void *const biasData = (op->getInputs(2)->getRawDataPtr()); void 
*const outData = (op->getOutput()->getRawDataPtr()); @@ -72,27 +77,26 @@ class convCudnn : public Kernel { cudnnTensorDescriptor_t inDesc; checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); checkCudnnError(cudnnSetTensor4dDescriptor( - inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w)); + inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, n, channels, h, w)); // get kernels cudnnFilterDescriptor_t knDesc; checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); - checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NCHW, f, - channelsPerGrp, r, s)); + checkCudnnError(cudnnSetFilter4dDescriptor( + knDesc, cudnnDataType, CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s)); // get bias cudnnTensorDescriptor_t biasDesc; checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor( - biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1)); + checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW, + cudnnDataType, 1, f, 1, 1)); - // get convlution descriptor + // get convolution descriptor cudnnConvolutionDescriptor_t convDesc; checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc)); // TODO: CUDNN_CONVOLUTION is a tunable argument checkCudnnError(cudnnSetConvolution2dDescriptor( convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode], - CUDNN_DATA_FLOAT)); + cudnnDataType)); if (g > 1) { checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g)); } @@ -120,14 +124,14 @@ class convCudnn : public Kernel { assert(false); } + // get output descriptor int outn, outc, outh, outw; checkCudnnError(cudnnGetConvolution2dForwardOutputDim( convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw)); cudnnTensorDescriptor_t outDesc; checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, outn, outc, - outh, outw)); + checkCudnnError(cudnnSetTensor4dDescriptor( + outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, outn, outc, outh, outw)); IT_ASSERT((vector{outn, outc, outh, outw}) == op->getOutput()->getDims(), "cuDNN output shape mismatches with OP output shape"); @@ -151,55 +155,9 @@ class convCudnn : public Kernel { inData, knDesc, knData, convDesc, ALGOS[record->algo], wsData, wsSize, &beta, outDesc, outData); - if (stat != CUDNN_STATUS_SUCCESS) + if (stat != CUDNN_STATUS_SUCCESS) { return false; - // TODO: - // // bias - // if (bias != nullptr) { - // auto sz = op.getOutputs()[0]->size(); - // // TODO: element wise - // t += sz * 2 / 400; - // } - // // act - // if (act != None) { - // stat = cudnnActivationForward(cudnnHandle(), actDesc, - // &alpha, inDesc, inData, - // &beta, outDesc, outData); - // checkCudaError(cudaDeviceSynchronize()); - // end = ch::high_resolution_clock::now(); - // if (stat != CUDNN_STATUS_SUCCESS) { - // durtime = INFINITY; - // break; - // } - // t += - // ch::duration_cast>(end - - // beg).count() * 1000; // ms - // } - - // best = ConvResult{durtime, ALGOS[i], wsSize, false}; - - // // w/ bias & act - // for (int j = 0; j < rounds + warmupRounds; ++j) { - // cudnnStatus_t stat; - // if (j == warmupRounds) { - // checkCudaError(cudaDeviceSynchronize()); - // beg = ch::high_resolution_clock::now(); - // } - // stat = cudnnConvolutionBiasActivationForward( - // cudnnHandle(), &alpha, inDesc, inData, knDesc, knData, - // convDesc, ALGOS[i], wsData, wsSize, &beta, outDesc, - // outData, biasDesc, biasData, actDesc, outDesc, outData); - // if (stat != CUDNN_STATUS_SUCCESS) { - // // checkCudnnError(stat); 
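The substantive change in this conv.cc hunk is that every cuDNN descriptor now takes cudnnDataTypeConvert(op->getDType()) instead of a hard-coded CUDNN_DATA_FLOAT, which is what lets one kernel serve both fp32 and fp16 and makes the separate conv_half.cc (deleted below) unnecessary. The real converter lives in cuda_utility.h/.cu, which are not shown in this section; a plausible shape for such a mapping, assuming only the types this patch touches, is sketched here under a hypothetical name.

#include <cudnn.h>
#include <stdexcept>

// Toy dtype enum standing in for the framework's DataType; the actual
// cudnnDataTypeConvert in cuda_utility.cu presumably covers more types.
enum class Dtype { Float32, Float16, Int8 };

inline cudnnDataType_t cudnnTypeOf(Dtype dt) {
    switch (dt) {
    case Dtype::Float32: return CUDNN_DATA_FLOAT;
    case Dtype::Float16: return CUDNN_DATA_HALF;
    case Dtype::Int8:    return CUDNN_DATA_INT8;
    default:
        throw std::runtime_error("no cuDNN equivalent for this dtype");
    }
}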
- // // Do not checkCudnnError since not all algorithms are - // // supported - // durtime_fuse = INFINITY; - // break; - // } - // } - - // Destories in CUDA does not require sync. But cuDNN does not state - // whether sync is required before destories. + } checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); @@ -238,10 +196,12 @@ class convCudnn : public Kernel { stat = cudnnGetConvolutionForwardWorkspaceSize( context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc, ALGOS[record.algo], &record.workspaceSize); - if (stat != CUDNN_STATUS_SUCCESS) + if (stat != CUDNN_STATUS_SUCCESS) { continue; - if (record.workspaceSize > context->getWorkspaceSize()) + } + if (record.workspaceSize > context->getWorkspaceSize()) { continue; + } CudaPtr wsData = context->getWorkspace(record.workspaceSize); float alpha = 1.f, beta = 0.f; @@ -249,8 +209,9 @@ class convCudnn : public Kernel { context->cudnnHandle(), &alpha, inDesc, inData, knDesc, knData, convDesc, ALGOS[record.algo], wsData, record.workspaceSize, &beta, outDesc, outData); - if (stat != CUDNN_STATUS_SUCCESS) + if (stat != CUDNN_STATUS_SUCCESS) { continue; + } record.time = timeit( [&]() { cudnnConvolutionForward(context->cudnnHandle(), &alpha, @@ -263,8 +224,9 @@ class convCudnn : public Kernel { // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time); // Update the tune result - if (ret.time > record.time) + if (ret.time > record.time) { ret = record; + } checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); @@ -291,8 +253,7 @@ class convCudnn : public Kernel { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn, - "Conv_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Conv, convCudnn, "Conv_cuDNN_CUDA"); REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json); } // namespace infini diff --git a/src/kernels/cuda/conv_half.cc b/src/kernels/cuda/conv_half.cc deleted file mode 100644 index 1f83b484..00000000 --- a/src/kernels/cuda/conv_half.cc +++ /dev/null @@ -1,261 +0,0 @@ -#include "core/kernel.h" -#include "cuda/cuda_runtime.h" -#include "operators/conv.h" -#include -#include -#include -#include - -namespace infini { - -struct ConvCuDnnPerfRecordObj : public PerfRecordObj { - int algo = 0; // cudnnConvolutionFwdAlgo_t - int mode = 1; - size_t workspaceSize = 100000; - bool fuseAct = false; - void to_json(json &j) override { - j["type"] = 1; - j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize); - } - static PerfRecord from_json(const json &j) { - ConvCuDnnPerfRecordObj tmp; - auto [Algo, Mode, FuseAct, Time, WorkspaceSize] = - j["data"].get>(); - tmp.algo = Algo; - tmp.mode = Mode; - tmp.fuseAct = FuseAct; - tmp.time = Time; - tmp.workspaceSize = WorkspaceSize; - return make_ref(tmp); - } -}; - -using ConvCuDnnPerfRecord = Ref; - -class convCudnnFP16 : public Kernel { - - static constexpr int N_ALGO = 8; - static constexpr int N_MODE = 2; - static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = { - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; - - static 
constexpr cudnnConvolutionMode_t MODES[2] = { - CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION}; - - std::tuple - createCuDNNDescriptor(const Ref &op, - const ConvCuDnnPerfRecord &record) const { - void *const inData = (op->getInputs(0)->getRawDataPtr()); - void *const knData = (op->getInputs(1)->getRawDataPtr()); - // Bias is not supported yet - if (op->getInputs().size() > 2) { - IT_TODO_HALT(); - } - // void *const biasData = (op->getInputs(2)->getRawDataPtr()); - void *const outData = (op->getOutput()->getRawDataPtr()); - - const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); - const int cpg = op->getChannelPerGroup(); - const int g = c / cpg; - const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); - - int channelsPerGrp = cpg, channels = c; - - // get inputs - cudnnTensorDescriptor_t inDesc; - checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_HALF, n, channels, - h, w)); /*fp16 type*/ - - // get kernels - cudnnFilterDescriptor_t knDesc; - checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); - checkCudnnError(cudnnSetFilter4dDescriptor( - knDesc, CUDNN_DATA_HALF, /*fp16 type*/ - CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s)); - // get bias - cudnnTensorDescriptor_t biasDesc; - checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_HALF, 1, f, 1, - 1)); /*fp16 type*/ - - // get convolution descriptor - cudnnConvolutionDescriptor_t convDesc; - checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc)); - // TODO: CUDNN_CONVOLUTION is a tunable argument - checkCudnnError(cudnnSetConvolution2dDescriptor( - convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode], - CUDNN_DATA_HALF)); /*fp16 type*/ - if (g > 1) { - checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g)); - } - - // get activation descriptor - cudnnActivationDescriptor_t actDesc; - checkCudnnError(cudnnCreateActivationDescriptor(&actDesc)); - // NOT_PROPAGATE_NAN is requierd by - // cudnnConvolotionBiasActivationForward - switch (op->getAct()) { - case ActType::Relu: - checkCudnnError(cudnnSetActivationDescriptor( - actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0)); - break; - case ActType::Sigmoid: - checkCudnnError(cudnnSetActivationDescriptor( - actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0)); - break; - case ActType::None: - checkCudnnError( - cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY, - CUDNN_NOT_PROPAGATE_NAN, 0)); - break; - default: - assert(false); - } - - // get output descriptor - int outn, outc, outh, outw; - checkCudnnError(cudnnGetConvolution2dForwardOutputDim( - convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw)); - cudnnTensorDescriptor_t outDesc; - checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_HALF, outn, outc, - outh, outw)); - IT_ASSERT((vector{outn, outc, outh, outw}) == - op->getOutput()->getDims(), - "cuDNN output shape mismatches with OP output shape"); - - return tuple(inData, knData, outData, inDesc, knDesc, biasDesc, - convDesc, actDesc, outDesc); - } - - bool cuDNNUnfused(const Ref &op, const ConvCuDnnPerfRecord &record, - const CudaRuntimeObj *context) const { - cudnnStatus_t stat; - - const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, - convDesc, actDesc, outDesc] = - createCuDNNDescriptor(op, record); - size_t wsSize = record->workspaceSize; - CudaPtr 
wsData = context->getWorkspace(wsSize); - float alpha = 1.f, beta = 0.f; - - stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc, - inData, knDesc, knData, convDesc, - ALGOS[record->algo], wsData, wsSize, - &beta, outDesc, outData); - if (stat != CUDNN_STATUS_SUCCESS) { - return false; - } - checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); - checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); - checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); - checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); - checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); - checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); - return true; - } - - void compute(const Operator &op, const RuntimeObj *context) const override { - auto record = make_ref(); // with paramters in - // default ctor - compute(op, record, context); - } - - PerfRecord tune(const Operator &_op, - const RuntimeObj *_context) const override { - ConvCuDnnPerfRecordObj ret; - ret.time = std::numeric_limits::max(); - auto context = dynamic_cast(_context); - auto op = as(_op); - // Both modes have the same performance. Only run cross-correlation. - for (int mode = 1; mode < 2; mode++) { - // Try every possible algorithm of convolution - for (int algo = 0; algo < N_ALGO; algo++) { - auto recordRef = make_ref(); - auto &record = *recordRef; - record.mode = mode; - record.algo = algo; - cudnnStatus_t stat; - const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, - convDesc, actDesc, outDesc] = - createCuDNNDescriptor(op, recordRef); - - // get workspace - stat = cudnnGetConvolutionForwardWorkspaceSize( - context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc, - ALGOS[record.algo], &record.workspaceSize); - if (stat != CUDNN_STATUS_SUCCESS) { - continue; - } - if (record.workspaceSize > context->getWorkspaceSize()) { - continue; - } - CudaPtr wsData = context->getWorkspace(record.workspaceSize); - float alpha = 1.f, beta = 0.f; - - stat = cudnnConvolutionForward( - context->cudnnHandle(), &alpha, inDesc, inData, knDesc, - knData, convDesc, ALGOS[record.algo], wsData, - record.workspaceSize, &beta, outDesc, outData); - if (stat != CUDNN_STATUS_SUCCESS) { - continue; - } - record.time = timeit( - [&]() { - cudnnConvolutionForward(context->cudnnHandle(), &alpha, - inDesc, inData, knDesc, knData, - convDesc, ALGOS[record.algo], - wsData, record.workspaceSize, - &beta, outDesc, outData); - }, - [&]() { context->sync(); }); - // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time); - - // Update the tune result - if (ret.time > record.time) { - ret = record; - } - checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); - checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); - checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); - checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); - checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); - checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); - } - } - // printf("the best algo is %d, the best conv mode is %d\n", ret.algo, - // ret.mode); - IT_ASSERT(ret.time < std::numeric_limits::max(), "No valid " - "algorithm " - "found"); - return make_ref(ret); - } - - void compute(const Operator &_op, const PerfRecord &_record, - const RuntimeObj *_context) const override { - auto op = as(_op); - auto record = as(_record); - auto context = dynamic_cast(_context); - bool success = cuDNNUnfused(op, record, context); - IT_ASSERT(success); - } -}; - -REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16, - 
"Conv_cuDNN_CUDA_Float16"); - -} // namespace infini diff --git a/src/kernels/cuda/conv_transposed.cc b/src/kernels/cuda/conv_transposed.cc index 4bd1b5e9..259fe8bb 100644 --- a/src/kernels/cuda/conv_transposed.cc +++ b/src/kernels/cuda/conv_transposed.cc @@ -219,6 +219,7 @@ class convBackwardDataCudnn : public Kernel { void compute(const Operator &op, const RuntimeObj *context) const override { // with paramters in default ctor auto record = make_ref(); + IT_ASSERT(op->getDType() == DataType::Float32); compute(op, record, context); } @@ -300,8 +301,9 @@ class convBackwardDataCudnn : public Kernel { } }; -REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32, - convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32, - convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, convBackwardDataCudnn, + "ConvTranposed_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, convBackwardDataCudnn, + "ConvTranposedNHWC_cuDNN_CUDA"); + } // namespace infini diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 8603c198..4a16de29 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -2,6 +2,7 @@ #include "cuda/cuda_element_wise.h" #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" namespace infini { class ElementWiseCudnn : public CudaKernelWithoutConfig { @@ -44,22 +45,21 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig { std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); + + auto cudnnDataType = cudnnDataTypeConvert(op->getDType()); // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&aDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, a[0], a[1], - a[2], a[3])); + checkCudnnError(cudnnSetTensor4dDescriptor( + aDesc, CUDNN_TENSOR_NCHW, cudnnDataType, a[0], a[1], a[2], a[3])); checkCudnnError(cudnnCreateTensorDescriptor(&bDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(bDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, b[0], b[1], - b[2], b[3])); + checkCudnnError(cudnnSetTensor4dDescriptor( + bDesc, CUDNN_TENSOR_NCHW, cudnnDataType, b[0], b[1], b[2], b[3])); // get outputs checkCudnnError(cudnnCreateTensorDescriptor(&cDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(cDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, c[0], c[1], - c[2], c[3])); + checkCudnnError(cudnnSetTensor4dDescriptor( + cDesc, CUDNN_TENSOR_NCHW, cudnnDataType, c[0], c[1], c[2], c[3])); // get op descriptor cudnnOpTensorDescriptor_t opDesc; @@ -127,40 +127,33 @@ class ElementWiseCuda : public CudaKernelWithoutConfig { std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); - if (op->getOpType() == OpType::Div) - div_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], - b[2], b[3], c[0], c[1], c[2], c[3]); - else if (op->getOpType() == OpType::Pow) - pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], - b[2], b[3], c[0], c[1], c[2], c[3]); - else if (op->getOpType() == OpType::Add) { - add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], - b[2], b[3], c[0], c[1], c[2], c[3]); + const int dType = _op->getDType().getIndex(); + if 
(op->getOpType() == OpType::Div) { + div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0], + b[1], b[2], b[3], c[0], c[1], c[2], c[3]); + } else if (op->getOpType() == OpType::Add) { + add_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0], + b[1], b[2], b[3], c[0], c[1], c[2], c[3]); + } else if (op->getOpType() == OpType::Pow) { + pow_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0], + b[1], b[2], b[3], c[0], c[1], c[2], c[3]); } else if (op->getOpType() == OpType::Less) { - less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], - b[2], b[3], c[0], c[1], c[2], c[3]); - } else + less_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], + b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]); + } else { IT_TODO_HALT(); + } } }; -REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn, - "Add_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn, - "Sub_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn, - "Mul_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn, - "Min_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn, - "Max_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Add, AddCudnn, "Add_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Sub, SubCudnn, "Sub_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Mul, MulCudnn, "Mul_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Min, MinCudnn, "Min_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Max, MaxCudnn, "Max_cuDNN_CUDA"); + +REGISTER_KERNEL(Device::CUDA, OpType::Div, ElementWiseCuda, "Div_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Pow, ElementWiseCuda, "Pow_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Less, ElementWiseCuda, "Less_CUDA"); -REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda, - "Div_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda, - "Add_CUDA_Int64"); -REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda, - "Pow__CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda, - "Less__CUDA_Int64"); }; // namespace infini diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu index 9d1b101a..98a12571 100644 --- a/src/kernels/cuda/element_wise.cu +++ b/src/kernels/cuda/element_wise.cu @@ -1,4 +1,5 @@ #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include constexpr unsigned int num_threads() { return 32 * 4; } @@ -129,44 +130,113 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2, } } +#define CASE(OP, T) \ + _##OP##_kernel::t><<>>( \ + a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + +#define SWITCH_DTYPE(OP, DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(OP, 1) \ + break; \ + case 2: \ + CASE(OP, 2) \ + break; \ + case 3: \ + CASE(OP, 3) \ + break; \ + case 4: \ + CASE(OP, 4) \ + break; \ + case 5: \ + CASE(OP, 5) \ + break; \ + case 6: \ + CASE(OP, 6) \ + break; \ + case 7: \ + CASE(OP, 7) \ + break; \ + case 10: \ + CASE(OP, 10) \ + break; \ + case 11: \ + CASE(OP, 11) \ + break; \ + case 12: \ + CASE(OP, 12) \ + break; \ + case 13: \ + CASE(OP, 13) \ + break; \ + case 16: \ + CASE(OP, 16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + namespace infini { -void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int 
b0, int b1, int b2, int b3, int c0, int c1, int c2, +void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _div_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, - b2, b3, c0, c1, c2, c3); + SWITCH_DTYPE(div, dType) } -void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, +void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _add_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, - b1, b2, b3, c0, c1, c2, c3); + SWITCH_DTYPE(add, dType) } -void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, +void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _pow_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, - b2, b3, c0, c1, c2, c3); + if (dType == 1) { + _pow_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, + b1, b2, b3, c0, c1, c2, c3); + } else if (dType == 3) { + _pow_kernel<<>>( + a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + } else if (dType == 10) { + int a_size = a0 * a1 * a2 * a3; + int b_size = b0 * b1 * b2 * b3; + int c_size = c0 * c1 * c2 * c3; + vector a_float(a_size); + vector b_float(b_size); + vector c_float(c_size); + for (int i = 0; i < a_size; ++i) { + a_float[i] = __half2float(((half *)a)[i]); + } + for (int i = 0; i < b_size; ++i) { + b_float[i] = __half2float(((half *)b)[i]); + } + _pow_kernel<<>>( + a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, + b1, b2, b3, c0, c1, c2, c3); + for (int i = 0; i < c_size; ++i) { + ((half *)c)[i] = __float2half(c_float[i]); + } + } else { + IT_TODO_HALT(); + } } -void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, - int b0, int b1, int b2, int b3, int c0, int c1, int c2, +void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _less_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, - b1, b2, b3, c0, c1, c2, c3); + SWITCH_DTYPE(less, dType) } }; // namespace infini diff --git a/src/kernels/cuda/expand.cc b/src/kernels/cuda/expand.cc index acbf5cd2..35b14f85 100644 --- a/src/kernels/cuda/expand.cc +++ b/src/kernels/cuda/expand.cc @@ -25,12 +25,12 @@ class ExpandCuda : public CudaKernelWithoutConfig { inputShape.data[i] = in_Shape[i]; outputsize *= out_Shape[i]; } - expandKernel((float *)inputData, (float *)outputData, nDims, outputsize, + const int dType = op->getDType().getIndex(); + expandKernel(dType, inputData, outputData, nDims, outputsize, inputShape, outputShape); } }; -REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda, - "Expand_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Expand, ExpandCuda, "Expand_CUDA"); }; // 
namespace infini diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu index 09405d09..af92b9ce 100644 --- a/src/kernels/cuda/expand.cu +++ b/src/kernels/cuda/expand.cu @@ -1,12 +1,14 @@ #include "core/common.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include "utils/small_array.h" constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _expandKernel(float *input, float *output, int nDims, +template +__global__ void _expandKernel(void *input, void *output, int nDims, int outputsize, infini::SmallArray inputShape, infini::SmallArray outputShape) { @@ -33,17 +35,64 @@ __global__ void _expandKernel(float *input, float *output, int nDims, temp *= inputShape.data[i]; v = v / outputShape.data[i]; } - output[outputIdx] = input[inputIdx]; + ((T *)output)[outputIdx] = ((T *)input)[inputIdx]; } } namespace infini { -void expandKernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape) { + +#define CASE(T) \ + _expandKernel::t><<>>( \ + input, output, nDims, outputsize, inputShape, outputShape); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + case 13: \ + CASE(13) \ + break; \ + case 16: \ + CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + +void expandKernel(int dType, void *input, void *output, int nDims, + int outputsize, SmallArray inputShape, + SmallArray outputShape) { int blocksize = block_work_size(); int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); - _expandKernel<<>>(input, output, nDims, outputsize, - inputShape, outputShape); + SWITCH_DTYPE(dType) } } // namespace infini diff --git a/src/kernels/cuda/extend.cc b/src/kernels/cuda/extend.cc index a5603e02..c8df7ff1 100644 --- a/src/kernels/cuda/extend.cc +++ b/src/kernels/cuda/extend.cc @@ -8,6 +8,7 @@ class ExtendCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto inData = op->getInputs(0)->getRawDataPtr(); auto outData = op->getOutputs()[0]->getRawDataPtr(); int blockSize = 1; @@ -22,6 +23,5 @@ class ExtendCuda : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda, - "Extend_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Extend, ExtendCuda, "Extend_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/gather.cc b/src/kernels/cuda/gather.cc index 54e6bd10..4417e3b4 100644 --- a/src/kernels/cuda/gather.cc +++ b/src/kernels/cuda/gather.cc @@ -15,12 +15,23 @@ class GatherCuda : public CudaKernelWithoutConfig { GatherMetaData metaData; initGatherMetaData(metaData, op); - auto inData = input->getRawDataPtr(); - auto outData = op->getOutput()->getRawDataPtr(); - gather_kernel(inData, outData, metaData, op->getOutput()->size()); + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + if (op->getDType() == DataType::Float32) { + 
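// [Illustrative sketch, not part of the patch] The dtype branches that GatherCuda gains here
// all funnel into one templated gather_kernel<T> whose definition lives in a separate .cu
// file, so each T the dispatcher uses must be explicitly instantiated there (the patch does
// this below for float, half and int8_t) or the host-side call fails to link. A minimal
// stand-alone version of that pattern, with a made-up `pick_kernel`:
#include <cuda_fp16.h>
#include <cstdint>
#include <cstddef>

template <typename T>
__global__ void _pick_kernel(const T *in, T *out, const size_t *idx, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = in[idx[i]]; // stand-in for gatheredOffset2Offset()
}

template <typename T>
void pick_kernel(const T *in, T *out, const size_t *idx, size_t n) {
    int block = 512;
    int grid = (int)((n + block - 1) / block);
    _pick_kernel<T><<<grid, block>>>(in, out, idx, n);
}

// Explicit instantiations: one line per element type the .cc dispatcher may pass.
template void pick_kernel<float>(const float *, float *, const size_t *, size_t);
template void pick_kernel<half>(const half *, half *, const size_t *, size_t);
template void pick_kernel<int8_t>(const int8_t *, int8_t *, const size_t *, size_t);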
gather_kernel((float *)inputData, (float *)outputData, + metaData, op->getOutput()->size()); + } else if (op->getDType() == DataType::Float16) { + gather_kernel((half *)inputData, (half *)outputData, metaData, + op->getOutput()->size()); + } else if (op->getDType() == DataType::Int8) { + gather_kernel((int8_t *)inputData, (int8_t *)outputData, + metaData, op->getOutput()->size()); + } else { + IT_ASSERT(false); + } } }; -REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda, - "Gather_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Gather, GatherCuda, "Gather_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/gather.cu b/src/kernels/cuda/gather.cu index 8ffeeac9..c9dedd95 100644 --- a/src/kernels/cuda/gather.cu +++ b/src/kernels/cuda/gather.cu @@ -28,27 +28,32 @@ __device__ T gatheredOffset2Offset(int gOffset, return offset; } -template -__global__ void _gather_kernel(float *in, float *out, +template +__global__ void _gather_kernel(dataT *in, dataT *out, infini::GatherMetaData metaData, size_t num) { T tid = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - while (tid < num) { + if (tid < num) { T offset = gatheredOffset2Offset(tid, metaData); out[tid] = in[offset]; - tid += stride; } } namespace infini { -void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) { +template +void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) { int blockSize = 32 * 16; int gridSize = (num + blockSize - 1) / blockSize; if (metaData.indexType == DataType::Int64) { - _gather_kernel + _gather_kernel <<>>(in, out, metaData, num); } else { - _gather_kernel<<>>(in, out, metaData, num); + _gather_kernel<<>>(in, out, metaData, num); } } +template void gather_kernel(float *in, float *out, + GatherMetaData metaData, size_t num); +template void gather_kernel(half *in, half *out, GatherMetaData metaData, + size_t num); +template void gather_kernel(int8_t *in, int8_t *out, + GatherMetaData metaData, size_t num); } // namespace infini diff --git a/src/kernels/cuda/gather_elements.cc b/src/kernels/cuda/gather_elements.cc index 795a5c6f..943f0209 100644 --- a/src/kernels/cuda/gather_elements.cc +++ b/src/kernels/cuda/gather_elements.cc @@ -21,8 +21,7 @@ class GatherElementsCuda : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32, - GatherElementsCuda, "GatherELements_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32, - GatherElementsCuda, "GatherElements_CUDA_Int32"); +REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, GatherElementsCuda, + "GatherELements_CUDA"); + } // namespace infini diff --git a/src/kernels/cuda/layer_norm.cc b/src/kernels/cuda/layer_norm.cc index a301eb0b..2cd3c786 100644 --- a/src/kernels/cuda/layer_norm.cc +++ b/src/kernels/cuda/layer_norm.cc @@ -24,22 +24,41 @@ class LayerNormCuda : public CudaKernelWithoutConfig { int dimsize = dims[op->getAxis()]; int size = op->getOutput(0)->size(); int scaleSize = op->getInputs(1)->size(); - if (op->numInputs() == 3) { - void *const biasData = (op->getInputs(2)->getRawDataPtr()); - int biasSize = op->getInputs(2)->size(); - // printf("kernel bias:true:%d\n", 1); - LaynormKernel((float *)inputData, (float *)scaleData, eps, size, - scaleSize, dimsize, stride, (float *)outputData, - (float *)biasData, biasSize); + if (op->getDType() == DataType::Float32) { + if (op->numInputs() == 3) { + void *const biasData = + (op->getInputs(2)->getRawDataPtr()); + int 
biasSize = op->getInputs(2)->size(); + // printf("kernel bias:true:%d\n", 1); + LaynormKernel((float *)inputData, (float *)scaleData, eps, size, + scaleSize, dimsize, stride, (float *)outputData, + (float *)biasData, biasSize); + } else { + // printf("kernel bias:false:%d\n", 0); + LaynormKernel((float *)inputData, (float *)scaleData, eps, size, + scaleSize, dimsize, stride, (float *)outputData); + } + } else if (op->getDType() == DataType::Float16) { + if (op->numInputs() == 3) { + void *const biasData = + (op->getInputs(2)->getRawDataPtr()); + int biasSize = op->getInputs(2)->size(); + // printf("kernel bias:true:%d\n", 1); + LaynormKernel((half *)inputData, (half *)scaleData, eps, size, + scaleSize, dimsize, stride, (half *)outputData, + (half *)biasData, biasSize); + } else { + // printf("kernel bias:false:%d\n", 0); + LaynormKernel((half *)inputData, (half *)scaleData, eps, size, + scaleSize, dimsize, stride, (half *)outputData); + } } else { - // printf("kernel bias:false:%d\n", 0); - LaynormKernel((float *)inputData, (float *)scaleData, eps, size, - scaleSize, dimsize, stride, (float *)outputData); + IT_ASSERT(false); } } }; -REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32, - LayerNormCuda, "LayerNorm_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, LayerNormCuda, + "LayerNorm_CUDA"); }; // namespace infini diff --git a/src/kernels/cuda/layer_norm.cu b/src/kernels/cuda/layer_norm.cu index c5e6e492..26f06e28 100644 --- a/src/kernels/cuda/layer_norm.cu +++ b/src/kernels/cuda/layer_norm.cu @@ -1,43 +1,41 @@ #include "cuda/cuda_common.h" #include -template +template __launch_bounds__(BLOCK_DIM) __global__ - void blockLaynormKernel(const float *input, const float *scale, - const int dimsize, const int stride, float *output, - const float eps, int scaleSize, const float *bias, - int biasSize) { + void blockLaynormKernel(const T *input, const T *scale, const int dimsize, + const int stride, T *output, const T eps, + int scaleSize, const T *bias, int biasSize) { // len(scale) = len(bias) = dimsize int tmp = blockIdx.x % stride; int tid = (blockIdx.x - tmp) * dimsize + tmp; - float muPartial = 0.0f; + T muPartial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; } - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float mu; - float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + __shared__ T mu; + T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory - mu = muBlock / dimsize; + mu = muBlock * static_cast(__fdividef(1.0F, dimsize)); } __syncthreads(); - float sigma2Partial = 0.0f; + T sigma2Partial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { sigma2Partial += (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); } - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; - __shared__ float sigma2; - float sigma2Block = - BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + __shared__ T sigma2; + T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory - sigma2 = sigma2Block / dimsize; + sigma2 = sigma2Block * 
static_cast(__fdividef(1.0F, dimsize)); } __syncthreads(); if (biasSize == dimsize) { @@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__ output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[threadIdx.x + ph * BLOCK_DIM] * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - - mu) / - sqrt(sigma2 + eps) + + mu) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast(sigma2 + eps)))) + bias[threadIdx.x + ph * BLOCK_DIM]; } } else { @@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__ output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[0] * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - - mu) / - sqrt(sigma2 + eps) + + mu) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast(sigma2 + eps)))) + bias[threadIdx.x + ph * BLOCK_DIM]; } } @@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__ output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[threadIdx.x + ph * BLOCK_DIM] * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - - mu) / - sqrt(sigma2 + eps) + + mu) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast(sigma2 + eps)))) + bias[0]; } } else { @@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__ output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[0] * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - - mu) / - sqrt(sigma2 + eps) + + mu) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast(sigma2 + eps)))) + bias[0]; } } } } //----------------- -template +template __launch_bounds__(BLOCK_DIM) __global__ - void blockLaynormKernel(const float *input, const float *scale, - const int dimsize, const int stride, float *output, - const float eps, int scaleSize) { + void blockLaynormKernel(const T *input, const T *scale, const int dimsize, + const int stride, T *output, const T eps, + int scaleSize) { // len(scale) = len(bias) = dimsize int tmp = blockIdx.x % stride; int tid = (blockIdx.x - tmp) * dimsize + tmp; - float muPartial = 0.0f; + T muPartial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; } - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float mu; - float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + __shared__ T mu; + T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory - mu = muBlock / dimsize; + mu = muBlock * static_cast(__fdividef(1.0F, dimsize)); } __syncthreads(); - float sigma2Partial = 0.0f; + T sigma2Partial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { sigma2Partial += (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); } - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; - __shared__ float sigma2; - float sigma2Block = - BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + __shared__ T sigma2; + T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); if (threadIdx.x == 0) { // must set threadIdx.x = 0 write the output to memory - sigma2 = sigma2Block / dimsize; + sigma2 = sigma2Block * static_cast(__fdividef(1.0F, dimsize)); } __syncthreads(); if (scaleSize == dimsize) { @@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__ output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[threadIdx.x + ph * BLOCK_DIM] * - 
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / - sqrt(sigma2 + eps); + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * + static_cast( + __fdividef(1.0F, sqrt(static_cast(sigma2 + eps)))); } } else { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = scale[0] * - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / - sqrt(sigma2 + eps); + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * + static_cast( + __fdividef(1.0F, sqrt(static_cast(sigma2 + eps)))); } } } @@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) { } return val; } -template -__global__ void warpLaynormKernel(const float *input, const float *scale, +template +__global__ void warpLaynormKernel(const T *input, const T *scale, const int dimsize, const int stride, - float *output, const float eps, int scaleSize, - int otherSize, const float *bias, - int biasSize) { + T *output, const T eps, int scaleSize, + int otherSize, const T *bias, int biasSize) { int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; if (otherIdx < otherSize) { - __shared__ float muTotal[BLOCK_DIM_y]; - __shared__ float sigma2Total[BLOCK_DIM_y]; + __shared__ T muTotal[BLOCK_DIM_y]; + __shared__ T sigma2Total[BLOCK_DIM_y]; - float muPartial = 0.0f; + T muPartial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; } - muPartial = WarpAllReduce(muPartial); + muPartial = WarpAllReduce(muPartial); if (threadIdx.x == 0) - muTotal[threadIdx.y] = muPartial / dimsize; + muTotal[threadIdx.y] = + muPartial * static_cast(__fdividef(1.0F, dimsize)); //-------------------------------------------- - float sigma2Partial = 0.0f; + T sigma2Partial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { sigma2Partial += @@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, muTotal[threadIdx.y]); } - sigma2Partial = WarpAllReduce(sigma2Partial); + sigma2Partial = WarpAllReduce(sigma2Partial); if (threadIdx.x == 0) - sigma2Total[threadIdx.y] = sigma2Partial / dimsize; + sigma2Total[threadIdx.y] = + sigma2Partial * static_cast(__fdividef(1.0F, dimsize)); //-------------------------------------------- if (biasSize == dimsize) { @@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, scale[threadIdx.x + ph * BLOCK_DIM_x] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps) + + muTotal[threadIdx.y]) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))) + bias[threadIdx.x + ph * BLOCK_DIM_x]; } } else { @@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, scale[0] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps) + + muTotal[threadIdx.y]) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))) + bias[threadIdx.x + ph * BLOCK_DIM_x]; } } @@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, scale[threadIdx.x + ph * BLOCK_DIM_x] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps) + + muTotal[threadIdx.y]) * + 
static_cast(__fdividef( + 1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))) + bias[0]; } } else { @@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, scale[0] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps) + + muTotal[threadIdx.y]) * + static_cast(__fdividef( + 1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))) + bias[0]; } } } } } -template -__global__ void warpLaynormKernel(const float *input, const float *scale, +template +__global__ void warpLaynormKernel(const T *input, const T *scale, const int dimsize, const int stride, - float *output, const float eps, int scaleSize, + T *output, const T eps, int scaleSize, int otherSize) { int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; if (otherIdx < otherSize) { - __shared__ float muTotal[BLOCK_DIM_y]; - __shared__ float sigma2Total[BLOCK_DIM_y]; + __shared__ T muTotal[BLOCK_DIM_y]; + __shared__ T sigma2Total[BLOCK_DIM_y]; - float muPartial = 0.0f; + T muPartial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; } - muPartial = WarpAllReduce(muPartial); + muPartial = WarpAllReduce(muPartial); if (threadIdx.x == 0) - muTotal[threadIdx.y] = muPartial / dimsize; + muTotal[threadIdx.y] = + muPartial * static_cast(__fdividef(1.0F, dimsize)); //-------------------------------------------- - float sigma2Partial = 0.0f; + T sigma2Partial = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { sigma2Partial += @@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, muTotal[threadIdx.y]); } - sigma2Partial = WarpAllReduce(sigma2Partial); + sigma2Partial = WarpAllReduce(sigma2Partial); if (threadIdx.x == 0) - sigma2Total[threadIdx.y] = sigma2Partial / dimsize; + sigma2Total[threadIdx.y] = + sigma2Partial * static_cast(__fdividef(1.0F, dimsize)); //-------------------------------------------- if (scaleSize == dimsize) { @@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = scale[threadIdx.x + ph * BLOCK_DIM_x] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps); + muTotal[threadIdx.y]) * + static_cast( + __fdividef(1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))); } } else { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { @@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale, output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = scale[0] * (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - muTotal[threadIdx.y]) / - sqrt(sigma2Total[threadIdx.y] + eps); + muTotal[threadIdx.y]) * + static_cast( + __fdividef(1.0F, sqrt(static_cast( + sigma2Total[threadIdx.y] + eps)))); } } } @@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, if (dimsize > 1024) { int BLOCK_DIM = 1024; - blockLaynormKernel<1024> + blockLaynormKernel <<>>(input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize); } else if (dimsize > 31) { @@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - 
warpLaynormKernel<32, 32><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else if (dimsize > 15) { @@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<16, 64><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else if (dimsize > 7) { @@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<8, 128><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else { @@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<4, 256><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } @@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, if (dimsize > 1024) { int BLOCK_DIM = 1024; - blockLaynormKernel<1024><<>>( + blockLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; @@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<32, 32><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; @@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<16, 64><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; @@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<8, 128><<>>( + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else { int BLOCK_DIM_x = 4; @@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<4, 256><<>>( + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block); + } +} +//----------------- +void LaynormKernel(const half *input, const half *scale, const half eps, + int size, int scaleSize, const int dimsize, const int stride, + half *output, const half *bias, int biasSize) { + int num_block = size / dimsize; + if (dimsize > 1024) { + int BLOCK_DIM = 1024; + + blockLaynormKernel + <<>>(input, scale, dimsize, stride, output, + eps, scaleSize, bias, biasSize); + } else if (dimsize > 31) { + int BLOCK_DIM_x = 32; + int BLOCK_DIM_y = 32; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block, + bias, biasSize); + } else if (dimsize > 15) { 
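// [Illustrative sketch, not part of the patch] The normalization step repeated in the
// templated LayerNorm kernels above can be read as: compute the reciprocal of
// sqrt(sigma2 + eps) in float via __fdividef, then scale in the tensor's own type T.
// Keeping the reciprocal in float presumably lets one code path serve both float and half;
// like the kernels above, this assumes an architecture with native half arithmetic.
#include <cuda_fp16.h>

template <typename T>
__device__ __forceinline__ T layernormOne(T x, T mu, T sigma2, T eps, T gamma, T beta) {
    float invStd = __fdividef(1.0F, sqrtf(static_cast<float>(sigma2 + eps)));
    return gamma * (x - mu) * static_cast<T>(invStd) + beta;
}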
+ int BLOCK_DIM_x = 16; + int BLOCK_DIM_y = 64; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block, + bias, biasSize); + } else if (dimsize > 7) { + int BLOCK_DIM_x = 8; + int BLOCK_DIM_y = 128; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block, + bias, biasSize); + } else { + int BLOCK_DIM_x = 4; + int BLOCK_DIM_y = 256; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block, + bias, biasSize); + } +} + +void LaynormKernel(const half *input, const half *scale, const half eps, + int size, int scaleSize, const int dimsize, const int stride, + half *output) { + int num_block = size / dimsize; + if (dimsize > 1024) { + int BLOCK_DIM = 1024; + + blockLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize); + } else if (dimsize > 31) { + int BLOCK_DIM_x = 32; + int BLOCK_DIM_y = 32; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block); + } else if (dimsize > 15) { + int BLOCK_DIM_x = 16; + int BLOCK_DIM_y = 64; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block); + } else if (dimsize > 7) { + int BLOCK_DIM_x = 8; + int BLOCK_DIM_y = 128; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( + input, scale, dimsize, stride, output, eps, scaleSize, num_block); + } else { + int BLOCK_DIM_x = 4; + int BLOCK_DIM_y = 256; + int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLaynormKernel<<>>( input, scale, dimsize, stride, output, eps, scaleSize, num_block); } } diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index 2d457cbc..e2addde1 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -2,6 +2,7 @@ #include "core/kernel.h" #include "cuda/cuda_expand.h" #include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" #include "utils/small_array.h" namespace infini { @@ -48,11 +49,12 @@ class matmulCublas : public Kernel { auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N; const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? 
k : n, ldc = n; - float alpha = 1.f, beta = 0.f; - if (op->numInputs() == 2) { // no bias - beta = 0.f; - } else { // broadcast bias to output - beta = 1.f; + float alpha_naive = 1.f, beta_naive = 0.f; + auto dataType = op->getDType(); + auto cuDataType = cublasDataTypeConvert(dataType); + IT_ASSERT(cuDataType != CUDA_R_8I, "matmul don't support int8 dtype."); + if (op->numInputs() == 3) { // have bias + beta_naive = 1.f; auto inC = op->getInputs(2); auto out = op->getOutput(); SmallArray inputShape, outputShape; @@ -69,8 +71,9 @@ class matmulCublas : public Kernel { if (i >= offset) inputShape.data[i] = inC->getDims()[i - offset]; } - expandKernel(inC->getRawDataPtr(), - out->getRawDataPtr(), nDims, outputsize, + const int dType = dataType.getIndex(); + expandKernel(dType, inC->getRawDataPtr(), + out->getRawDataPtr(), nDims, outputsize, inputShape, outputShape); } // TODO:use compute type @@ -89,16 +92,38 @@ class matmulCublas : public Kernel { (dimB == 3 && op->getInputs(1)->getDims()[0] == 1)) ? 0 // Broadcast the batch dimension if batch size is 1 : n * k; - stat = cublasGemmStridedBatchedEx( - context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData, - CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA, - &beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, - (cublasGemmAlgo_t)record->algo); + if (dataType == DataType::Float16) { + half alpha_half = static_cast(alpha_naive); + half beta_half = static_cast(beta_naive); + stat = cublasGemmStridedBatchedEx( + context->cublasHandle(), opB, opA, n, m, k, &alpha_half, + inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda, + strideA, &beta_half, outData, cuDataType, ldc, m * n, b, + cuDataType, (cublasGemmAlgo_t)record->algo); + + } else { + stat = cublasGemmStridedBatchedEx( + context->cublasHandle(), opB, opA, n, m, k, &alpha_naive, + inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda, + strideA, &beta_naive, outData, cuDataType, ldc, m * n, b, + cuDataType, (cublasGemmAlgo_t)record->algo); + } } else { - stat = cublasGemmEx( - context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData, - CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData, - CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo); + if (dataType == DataType::Float16) { + half alpha_half = static_cast(alpha_naive); + half beta_half = static_cast(beta_naive); + stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k, + &alpha_half, inBData, cuDataType, ldb, + inAData, cuDataType, lda, &beta_half, + outData, cuDataType, ldc, cuDataType, + (cublasGemmAlgo_t)record->algo); + } else { + stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k, + &alpha_naive, inBData, cuDataType, ldb, + inAData, cuDataType, lda, &beta_naive, + outData, cuDataType, ldc, cuDataType, + (cublasGemmAlgo_t)record->algo); + } } // if (stat != CUBLAS_STATUS_SUCCESS) // cout << cublasGetErrorString(stat); @@ -140,8 +165,9 @@ class matmulCublas : public Kernel { } }; -REGISTER_KERNEL(Device::CUDA, OpType::MatMul, DataType::Float32, matmulCublas, - "Matmul_cuBLAS_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::MatMul, matmulCublas, + "Matmul_cuBLAS_CUDA"); REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json); + }; // namespace infini diff --git a/src/kernels/cuda/membound_tvm_extract_source.cc b/src/kernels/cuda/membound_tvm_extract_source.cc index e4b76e60..57e5b9d1 100644 --- a/src/kernels/cuda/membound_tvm_extract_source.cc +++ b/src/kernels/cuda/membound_tvm_extract_source.cc @@ -229,9 +229,8 @@ class MemboundTVMExtractSource : 
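// [Illustrative sketch, not part of the patch] The fp16 branch added to matmulCublas above
// hinges on one cuBLAS rule: when the A/B/C types and the compute type are CUDA_R_16F, the
// alpha/beta scalars must also be passed as half. A stand-alone column-major call with
// hypothetical device buffers dA, dB, dC:
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t gemmFp16(cublasHandle_t handle, const half *dA, const half *dB, half *dC,
                        int m, int n, int k) {
    half alpha = static_cast<half>(1.0f); // must be half, not float, for CUDA_R_16F compute
    half beta = static_cast<half>(0.0f);
    // C(m x n) = A(m x k) * B(k x n), all column-major, no transposes.
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                        dA, CUDA_R_16F, m, dB, CUDA_R_16F, k, &beta,
                        dC, CUDA_R_16F, m, CUDA_R_16F, CUBLAS_GEMM_DEFAULT);
}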
public Kernel { } }; -// REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, -// MemboundTVMExtractSource, -// "Memobund_TVM_Ansor_extract_source"); +REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMExtractSource, + "Memobund_TVM_Ansor_extract_source"); }; // namespace infini #endif diff --git a/src/kernels/cuda/membound_tvm_packed_function.cc b/src/kernels/cuda/membound_tvm_packed_function.cc index 8086518d..a8af951e 100644 --- a/src/kernels/cuda/membound_tvm_packed_function.cc +++ b/src/kernels/cuda/membound_tvm_packed_function.cc @@ -216,9 +216,9 @@ class MemboundTVMPackedFunction : public Kernel { } }; -REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, - MemboundTVMPackedFunction, +REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMPackedFunction, "Memobund_TVM_Ansor_packed_funciton"); + }; // namespace infini #endif diff --git a/src/kernels/cuda/pad_slice.cc b/src/kernels/cuda/pad_slice.cc index 1ff4dffa..0039c11c 100644 --- a/src/kernels/cuda/pad_slice.cc +++ b/src/kernels/cuda/pad_slice.cc @@ -39,10 +39,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda, - "Slice__CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda, - "Slice__CUDA_Int64"); -REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda, - "Pad__CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Slice, SliceCuda, "Slice__CUDA"); + +REGISTER_KERNEL(Device::CUDA, OpType::Pad, PadCuda, "Pad__CUDA"); + } // namespace infini diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu index cd6bc37b..ccf85748 100644 --- a/src/kernels/cuda/pad_slice.cu +++ b/src/kernels/cuda/pad_slice.cu @@ -1,6 +1,7 @@ #include "core/data_type.h" #include "cuda/cuda_common.h" #include "cuda/cuda_pad_slice.h" +#include "cuda/cuda_utility.h" __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset, TransMetaData metaData, @@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset, } template -__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData, - int nDims, int num, bool isPad) { +__global__ void _pad_slice_kernel(void *part, void *whole, + TransMetaData metaData, int nDims, int num, + bool isPad) { int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= num) + if (tid >= num) { return; + } int stride = blockDim.x * gridDim.x; while (tid < num) { int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims); - if (isPad) - if (offset < 0) - whole[tid] = 0; - else - whole[tid] = part[offset]; - else if (offset >= 0) - part[offset] = whole[tid]; + if (isPad) { + if (offset < 0) { + ((T *)whole)[tid] = static_cast(0.f); + } else { + ((T *)whole)[tid] = ((T *)part)[offset]; + } + } else if (offset >= 0) { + ((T *)part)[offset] = ((T *)whole)[tid]; + } tid += stride; } } namespace infini { +#define CASE(T) \ + _pad_slice_kernel::t><<>>( \ + partData, wholeData, metadata, nDims, num, isPad); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + case 13: \ + CASE(13) \ + break; \ + case 16: \ + 
CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + void pad_slice_kernel(void *partData, void *wholeData, const TransMetaData &metadata, int nDims, int num, bool isPad) { int blockSize = 32 * 16; int gridSize = (num + blockSize - 1) / blockSize; - if (metadata.DType == DataType::Int64.getIndex()) { - _pad_slice_kernel - <<>>((int64_t *)partData, (int64_t *)wholeData, - metadata, nDims, num, isPad); - } else if (metadata.DType == DataType::Float32.getIndex()) { - _pad_slice_kernel<<>>( - (float *)partData, (float *)wholeData, metadata, nDims, num, isPad); - } + int dType = metadata.DType; + SWITCH_DTYPE(dType) } } // namespace infini diff --git a/src/kernels/cuda/pooling.cc b/src/kernels/cuda/pooling.cc index d8b2e0f8..03d5b883 100644 --- a/src/kernels/cuda/pooling.cc +++ b/src/kernels/cuda/pooling.cc @@ -8,6 +8,7 @@ class poolingCudnn : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); @@ -76,8 +77,9 @@ class avgPoolCudnn : public poolingCudnn { } }; -REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn, - "MaxPool_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, DataType::Float32, - avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, maxPoolCudnn, + "MaxPool_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, avgPoolCudnn, + "AvgPool_cuDNN_CUDA"); + }; // namespace infini diff --git a/src/kernels/cuda/recv.cc b/src/kernels/cuda/recv.cc index 7fd7ee49..42c9073e 100644 --- a/src/kernels/cuda/recv.cc +++ b/src/kernels/cuda/recv.cc @@ -40,8 +40,7 @@ class RecvNCCL : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL, - "Recv_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Recv, RecvNCCL, "Recv_NCCL_CUDA"); } // namespace infini #endif diff --git a/src/kernels/cuda/reduce.cc b/src/kernels/cuda/reduce.cc index 840a572f..531c09d0 100644 --- a/src/kernels/cuda/reduce.cc +++ b/src/kernels/cuda/reduce.cc @@ -1,6 +1,7 @@ #include "operators/reduce.h" #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" namespace infini { class ReduceCudnnBase : public CudaKernelWithoutConfig { @@ -46,12 +47,12 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig { checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); cudnnTensorDescriptor_t outDesc; checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); + auto cudnnDataType = cudnnDataTypeConvert(op->getDType()); if (nInDims > 3) { checkCudnnError(cudnnSetTensorNdDescriptor( - inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray)); - checkCudnnError( - cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims, - outDimArray, outStrideArray)); + inDesc, cudnnDataType, nInDims, inDimArray, inStrideArray)); + checkCudnnError(cudnnSetTensorNdDescriptor( + outDesc, cudnnDataType, nInDims, outDimArray, outStrideArray)); } else { int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1}; for (int i = 0; i < nInDims; ++i) { @@ -62,20 +63,19 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig { } checkCudnnError(cudnnSetTensor4dDescriptor( - inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0], idims[1], + inDesc, CUDNN_TENSOR_NCHW, 
cudnnDataType, idims[0], idims[1], idims[2], idims[3])); checkCudnnError(cudnnSetTensor4dDescriptor( - outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0], - odims[1], odims[2], odims[3])); + outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, odims[0], odims[1], + odims[2], odims[3])); } // get reduce descriptor cudnnReduceTensorDescriptor_t reduceDesc; checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc)); checkCudnnError(cudnnSetReduceTensorDescriptor( - reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT, - CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, - CUDNN_32BIT_INDICES)); + reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN, + CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES)); // get workspace size_t workspaceSize = 0; @@ -120,8 +120,9 @@ class ReduceSumCudnn : public ReduceCudnnBase { } }; -REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32, - ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32, - ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, ReduceMeanCudnn, + "ReduceMean_cuDNN_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, ReduceSumCudnn, + "ReduceSum_cuDNN_CUDA"); + }; // namespace infini diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc index 232bcdf6..450105b0 100644 --- a/src/kernels/cuda/reshape.cc +++ b/src/kernels/cuda/reshape.cc @@ -11,19 +11,12 @@ class CopyCuda : public CudaKernelWithoutConfig { } }; // reshape/flatten/identity all act as copying from input to output. -REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda, - "Reshape_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda, - "Reshape_CUDA_Int64"); -REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda, - "Reshape_CUDA_Int32"); -REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda, - "Flatten_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda, - "Squeeze_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda, - "Unsqueeze_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda, - "Identity_CUDA_Float32"); + +REGISTER_KERNEL(Device::CUDA, OpType::Reshape, CopyCuda, "Reshape_CUDA"); + +REGISTER_KERNEL(Device::CUDA, OpType::Flatten, CopyCuda, "Flatten_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Identity, CopyCuda, "Identity_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, CopyCuda, "Squeeze_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, CopyCuda, "Unsqueeze_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/resize.cc b/src/kernels/cuda/resize.cc index 5becb913..106b46f3 100644 --- a/src/kernels/cuda/resize.cc +++ b/src/kernels/cuda/resize.cc @@ -6,6 +6,7 @@ class ResizeCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto in = op->getInputs(0); auto out = op->getOutputs()[0]; @@ -48,7 +49,6 @@ class ResizeCuda : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Resize, DataType::Float32, ResizeCuda, - "Resize_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Resize, ResizeCuda, "Resize_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/send.cc b/src/kernels/cuda/send.cc index 38684062..6f8af9aa 
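// [Illustrative sketch, not part of the patch] With DataType removed from REGISTER_KERNEL,
// kernels that still only handle fp32 (Extend, Pooling, Resize in this patch) guard
// themselves at run time instead of relying on the registry key. The class name below is a
// placeholder; the guard mirrors the IT_ASSERT the patch adds to those kernels.
class Fp32OnlyCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op, const RuntimeObj *_context) const override {
        // Fail fast rather than silently reinterpreting an fp16/int8 buffer as float.
        IT_ASSERT(_op->getDType() == DataType::Float32);
        // ... existing float-only implementation ...
    }
};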
100644 --- a/src/kernels/cuda/send.cc +++ b/src/kernels/cuda/send.cc @@ -36,8 +36,7 @@ class SendNCCL : public CudaKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL, - "Send_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Send, SendNCCL, "Send_NCCL_CUDA"); } // namespace infini #endif diff --git a/src/kernels/cuda/softmax.cc b/src/kernels/cuda/softmax.cc index 024288c2..4a2d844b 100644 --- a/src/kernels/cuda/softmax.cc +++ b/src/kernels/cuda/softmax.cc @@ -20,11 +20,17 @@ class SoftmaxCuda : public CudaKernelWithoutConfig { int stride = op->getInputs(0)->getStride().at(op->getAxis()); int num_blocks = size / dimsize; - softmax_kernel(num_blocks, (float *)input, (float *)output, size, - dimsize, stride); + if (op->getDType() == DataType::Float32) { + softmax_kernel(num_blocks, (float *)input, (float *)output, size, + dimsize, stride); + } else if (op->getDType() == DataType::Float16) { + softmax_kernel(num_blocks, (half *)input, (half *)output, size, + dimsize, stride); + } else { + IT_ASSERT(false); + } } }; -REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda, - "Softmax_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Softmax, SoftmaxCuda, "Softmax_CUDA"); } // namespace infini diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu index 7e85ec43..69334d50 100644 --- a/src/kernels/cuda/softmax.cu +++ b/src/kernels/cuda/softmax.cu @@ -1,6 +1,5 @@ #include "cuda/cuda_common.h" #include - struct __align__(8) DataMaxSum { // update the global max and sum, store the // output at max_tmp and sum_tmp float max_tmp; // store max @@ -16,9 +15,9 @@ __device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a, return bigger; } -template +template __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel( - float *__restrict input, float *__restrict output, int size, int dimsize, + T *__restrict input, T *__restrict output, int size, int dimsize, int stride) { // if set axis = 1, inputShape=[I,J,K,S] // tid = i(JKS) + j(KS) + k(S) + s @@ -33,15 +32,33 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel( dms_partial.max_tmp = -__FLT_MAX__; dms_partial.sum_tmp = 0.0f; DataMaxSum dms_input; - for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + int remain = dimsize % BLOCK_DIM; + int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread - dms_input.max_tmp = - input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + dms_input.max_tmp = + input[tid + (threadIdx.x * step + ind) * stride]; - dms_input.sum_tmp = 1.0f; - dms_partial = reduce_dms_op(dms_partial, - dms_input); // reduce the data to one block + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input); // reduce the data to one block + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + dms_input.max_tmp = + input[tid + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]; + + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input); // reduce the data to one block + } } + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ DataMaxSum dms_total; @@ -53,12 +70,102 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel( } __syncthreads(); //----------------- + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { - for (int ph = 0; threadIdx.x + ph * 
BLOCK_DIM < dimsize; ph++) { - output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = - __expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - - dms_total.max_tmp) * - __fdividef(1.0F, dms_total.sum_tmp); + output[tid + (threadIdx.x * step + ind) * stride] = + __expf(static_cast( + input[tid + (threadIdx.x * step + ind) * stride]) - + dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp); + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + + output[tid + + (remain * step + (threadIdx.x - remain) * (step - 1) + ind) * + stride] = + __expf(static_cast( + input[tid + + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]) - + dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp); + } + } +} + +template +__global__ void +_blockSoftmaxKernel(T *__restrict input, T *__restrict output, int size, + int dimsize, + int stride) { // if set axis = 1, inputShape=[I,J,K,S] + // tid = i(JKS) + j(KS) + k(S) + s + + // blockDim.x = size/dimsize = IKS + // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s + + int tid = + blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * + dimsize; // now, tid = i(JKS) + k(S) + s; + int remain = dimsize % BLOCK_DIM; + int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread + float dataPerThread[numPerThread]; + + DataMaxSum dms_partial; + dms_partial.max_tmp = -__FLT_MAX__; + dms_partial.sum_tmp = 0.0f; + DataMaxSum dms_input; + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + dataPerThread[ind] = + input[tid + (threadIdx.x * step + ind) * stride]; + dms_input.max_tmp = dataPerThread[ind]; + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input); // reduce the data to one block + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + dataPerThread[ind] = + input[tid + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]; + dms_input.max_tmp = dataPerThread[ind]; + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input); // reduce the data to one block + } + } + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ DataMaxSum dms_total; + DataMaxSum dms_block = + BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + dms_total = dms_block; + } + __syncthreads(); + //----------------- + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + output[tid + (threadIdx.x * step + ind) * stride] = + __expf(dataPerThread[ind] - dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp); + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + output[tid + + (remain * step + (threadIdx.x - remain) * (step - 1) + ind) * + stride] = + __expf(dataPerThread[ind] - dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp); + } } } @@ -81,14 +188,14 @@ __inline__ __device__ T WarpAllReduce(T val) { } return val; } -template -__global__ void _warpSoftmaxKernel(float *__restrict input, - float *__restrict output, int size, - int dimsize, int stride) { + +template +__global__ void _warpSoftmaxKernel(T *__restrict input, T *__restrict output, + int size, int dimsize, int stride) { int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int otherSize = size / dimsize; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; - + float dataPerThreadx[numPerThreadx]; if (otherIdx < otherSize) { __shared__ float 
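// [Illustrative sketch, not part of the patch] WarpAllReduce, which the warp-level softmax
// and LayerNorm kernels here rely on, appears to be a butterfly reduction over shuffle
// instructions parameterized by the reduction op, the element type and the thread-group
// width. A plain float/max version of the same idea:
__device__ __forceinline__ float warpAllReduceMax(float val, int width = 32) {
    for (int mask = width / 2; mask > 0; mask >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, mask));
    return val; // every lane in the group now holds the maximum
}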
max_total[BLOCK_DIM_y]; @@ -96,9 +203,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input, float max_data = -__FLT_MAX__; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { - max_data = - max(max_data, - input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]); + dataPerThreadx[ph] = + input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; + max_data = max(max_data, dataPerThreadx[ph]); } max_data = WarpAllReduce(max_data); @@ -110,9 +217,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input, float sum_data = 0.0f; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { - sum_data += - __expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - max_total[threadIdx.y]); + dataPerThreadx[ph] = + __expf(dataPerThreadx[ph] - max_total[threadIdx.y]); + sum_data += dataPerThreadx[ph]; } sum_data = WarpAllReduce(sum_data); @@ -124,9 +231,7 @@ __global__ void _warpSoftmaxKernel(float *__restrict input, for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = - __expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - - max_total[threadIdx.y]) * - __fdividef(1.0F, sum_total[threadIdx.y]); + dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y]); } } } @@ -137,10 +242,35 @@ namespace infini { void softmax_kernel(int num_blocks, float *input, float *output, int size, int dimsize, int stride) { - if (dimsize > 1024) { + if (dimsize > 1024 * 128) { int BLOCK_DIM = 1024; - _blockSoftmaxKernel<1024> + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 64) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 32) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 16) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 4) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel <<>>(input, output, size, dimsize, stride); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; @@ -149,7 +279,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - _warpSoftmaxKernel<32, 32> + _warpSoftmaxKernel <<>>(input, output, size, dimsize, stride); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; @@ -158,7 +288,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - _warpSoftmaxKernel<16, 64> + _warpSoftmaxKernel <<>>(input, output, size, dimsize, stride); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; @@ -167,7 +297,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - _warpSoftmaxKernel<8, 128> + _warpSoftmaxKernel <<>>(input, output, size, dimsize, stride); } else { int BLOCK_DIM_x = 4; @@ -176,7 +306,79 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - _warpSoftmaxKernel<4, 256> + _warpSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } +} +//------------------ +void 
softmax_kernel(int num_blocks, half *input, half *output, int size, + int dimsize, int stride) { + + if (dimsize > 1024 * 128) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 64) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 32) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 16) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024 * 4) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 1024) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 31) { + int BLOCK_DIM_x = 32; + int BLOCK_DIM_y = 32; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 15) { + int BLOCK_DIM_x = 16; + int BLOCK_DIM_y = 64; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else if (dimsize > 7) { + int BLOCK_DIM_x = 8; + int BLOCK_DIM_y = 128; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>(input, output, size, dimsize, stride); + } else { + int BLOCK_DIM_x = 4; + int BLOCK_DIM_y = 256; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel <<>>(input, output, size, dimsize, stride); } } diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc index d3f8a551..e06ef731 100644 --- a/src/kernels/cuda/split_concat.cc +++ b/src/kernels/cuda/split_concat.cc @@ -7,7 +7,8 @@ namespace infini { class CudaCompute { - void initComposedTensorMetadata(ComposedTensorMetadata &metadata, + template + void initComposedTensorMetadata(ComposedTensorMetadata &metadata, Tensor tensor) const { int nDims = tensor->getRank(); auto strides = tensor->getStride(); @@ -16,10 +17,10 @@ class CudaCompute { metadata.dimSize[i] = tensor->getDims().at(i); metadata.stride[i] = strides.at(i); } - metadata.data = tensor->getRawDataPtr(); + metadata.data = tensor->getRawDataPtr(); } - - void initElementTensorMetadata(ElementTensorMetadata &metadata, + template + void initElementTensorMetadata(ElementTensorMetadata &metadata, TensorVec tensors, int idx, int dim, int &dimBgIdx, int &batchCounter) const { int nTensors = tensors.size(); @@ -27,7 +28,7 @@ class CudaCompute { ++batchCounter) { auto tensor = tensors.at(idx + batchCounter); auto dimSize = tensor->getDims()[dim]; - metadata.data[batchCounter] = tensor->getRawDataPtr(); + metadata.data[batchCounter] = tensor->getRawDataPtr(); metadata.dimBgNo[batchCounter] = dimBgIdx; metadata.dimSize[batchCounter] = dimSize; metadata.nElements[batchCounter] = tensor->size(); @@ -36,17 +37,17 @@ class CudaCompute { } public: + template void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim, int nDims, bool isSplit) const { 
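// [Illustrative sketch, not part of the patch] The graduated softmax launch tiers above
// (dimsize > 1024*128, > 1024*64, ...) pair with the new _blockSoftmaxKernel variant that
// keeps dataPerThread[numPerThread] in registers: with BLOCK_DIM = 1024 threads, a thread
// owns about ceil(dimsize / 1024) elements, so each tier's (elided) per-thread capacity
// has to cover that count.
constexpr int elemsPerThread(int dimsize, int blockDim = 1024) {
    return (dimsize + blockDim - 1) / blockDim;
}
// Worked example: dimsize = 50000 -> ceil(50000 / 1024) = 49 register slots per thread,
// which lands in the "dimsize > 1024 * 32" tier.
static_assert(elemsPerThread(50000) == 49, "softmax tier arithmetic");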
IT_ASSERT(nDims <= DIM_MAX_SIZE); - - ComposedTensorMetadata composedMetadata; - initComposedTensorMetadata(composedMetadata, composedTensor); + ComposedTensorMetadata composedMetadata; + initComposedTensorMetadata(composedMetadata, composedTensor); int dimBgNo = 0; int nElemets = elementsTensor.size(); for (int i = 0; i < nElemets; i += BATCH_SIZE) { - ElementTensorMetadata elemMetadata; + ElementTensorMetadata elemMetadata; int batchCounter = 0; initElementTensorMetadata(elemMetadata, elementsTensor, i, dim, dimBgNo, batchCounter); @@ -74,23 +75,38 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { } } } - do_compute(_op->getOutput(), _op->getInputs(), - as(_op)->getDim(), _op->getOutput()->getRank(), - false); + if (_op->getDType() == DataType::Float32) { + do_compute(_op->getOutput(), _op->getInputs(), + as(_op)->getDim(), + _op->getOutput()->getRank(), false); + } else if (_op->getDType() == DataType::Float16) { + do_compute(_op->getOutput(), _op->getInputs(), + as(_op)->getDim(), + _op->getOutput()->getRank(), false); + } else { + IT_ASSERT(false); + } } }; class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { - do_compute(_op->getInputs(0), _op->getOutputs(), - as(_op)->getDim(), _op->getInputs(0)->getRank(), - true); + if (_op->getDType() == DataType::Float32) { + do_compute(_op->getInputs(0), _op->getOutputs(), + as(_op)->getDim(), + _op->getInputs(0)->getRank(), true); + } else if (_op->getDType() == DataType::Float16) { + do_compute(_op->getInputs(0), _op->getOutputs(), + as(_op)->getDim(), + _op->getInputs(0)->getRank(), true); + } else { + IT_ASSERT(false); + } } }; -REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda, - "Concat_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda, - "Split_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Concat, ConcatCuda, "Concat_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Split, SplitCuda, "Split_CUDA"); + } // namespace infini diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu index 193501e0..fdb5f18c 100644 --- a/src/kernels/cuda/split_concat.cu +++ b/src/kernels/cuda/split_concat.cu @@ -1,9 +1,9 @@ #include "cuda/cuda_common.h" #include "cuda/cuda_split_concat.h" - +template __host__ __device__ int elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim, - int nDim, ComposedTensorMetadata wholeMeta) { + int nDim, ComposedTensorMetadata wholeMeta) { int offset = 0; // COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1) @@ -25,10 +25,10 @@ elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim, int oP = (dim == 0) ? 
(elementIndex + dimBgNo) : elementIndex; return offset + oP * wholeMeta.stride[0]; } - -__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, - ComposedTensorMetadata compMeta, int dim, - int nDims, bool isSplit) { +template +__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, + ComposedTensorMetadata compMeta, + int dim, int nDims, bool isSplit) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int nElements = elemMeta.nElements[blockIdx.y]; if (tid >= nElements) @@ -36,10 +36,10 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, auto dimBgNo = elemMeta.dimBgNo[blockIdx.y]; auto dimSize = elemMeta.dimSize[blockIdx.y]; - float *elemData = elemMeta.data[blockIdx.y]; + T *elemData = elemMeta.data[blockIdx.y]; int Offset = - elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); + elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); // copy data from input to output // for split:input is composed tensor;for concat:input is element // tensors. @@ -52,8 +52,22 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, namespace infini { // TODO: when dim=0, the operation can be executed in-place -void split_concat_kernel(const ElementTensorMetadata &eleMeta, - const ComposedTensorMetadata &compMeta, int dim, +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, + int batchSize, int nDims, bool isSplit) { + dim3 blockSize = dim3(32 * 16); + // gridsize = max_n_elements / blockSize + int max_n_elements = + *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize); + int gridDimX = (max_n_elements - 1) / (32 * 16) + 1; + // each y is a split among the batch + dim3 gridSize(gridDimX, batchSize); + + _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, + isSplit); +} +void split_concat_kernel(const ElementTensorMetadata &eleMeta, + const ComposedTensorMetadata &compMeta, int dim, int batchSize, int nDims, bool isSplit) { dim3 blockSize = dim3(32 * 16); // gridsize = max_n_elements / blockSize diff --git a/src/kernels/cuda/transpose.cc b/src/kernels/cuda/transpose.cc index 774cb37f..b22ee3dd 100644 --- a/src/kernels/cuda/transpose.cc +++ b/src/kernels/cuda/transpose.cc @@ -38,8 +38,9 @@ class TransposeCuda : public CudaKernelWithoutConfig { outputDims.data[i] = outputShape[i]; } - transpose_kernel((float *)inputData, (float *)outputData, nDims, size, - strides, outputDims); + const int dType = op->getDType().getIndex(); + transpose_kernel(dType, inputData, outputData, nDims, size, strides, + outputDims); } }; @@ -82,15 +83,16 @@ class DepthToSpaceCuda : public CudaKernelWithoutConfig { for (int i = 0; i < nDims; ++i) { outputDims.data[i] = transpose[i]; } - - transpose_kernel((float *)inputData, (float *)outputData, nDims, size, - strides, outputDims); + const int dType = op->getDType().getIndex(); + transpose_kernel(dType, inputData, outputData, nDims, size, strides, + outputDims); } }; -REGISTER_KERNEL(Device::CUDA, OpType::Transpose, DataType::Float32, - TransposeCuda, "Transpose_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Transpose, TransposeCuda, + "Transpose_CUDA"); + +REGISTER_KERNEL(Device::CUDA, OpType::DepthToSpace, DepthToSpaceCuda, + "DepthToSpace_CUDA"); -REGISTER_KERNEL(Device::CUDA, OpType::DepthToSpace, DataType::Float32, - DepthToSpaceCuda, "DepthToSpace_CUDA_Float32"); } // namespace infini diff --git a/src/kernels/cuda/transpose.cu b/src/kernels/cuda/transpose.cu index f753217c..917afde3 100644 --- 
a/src/kernels/cuda/transpose.cu +++ b/src/kernels/cuda/transpose.cu @@ -1,12 +1,14 @@ #include "core/common.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include "utils/small_array.h" constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _transpose_kernel(float *input, float *output, int nDims, +template +__global__ void _transpose_kernel(void *input, void *output, int nDims, int size, infini::SmallArray strides, infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; @@ -17,21 +19,61 @@ __global__ void _transpose_kernel(float *input, float *output, int nDims, inputIdx += v % outputShape.data[i] * strides.data[i]; v /= outputShape.data[i]; } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - output[outputIdx] = __ldg(input + inputIdx); -#else - output[outputIdx] = input[inputIdx]; -#endif + ((T *)output)[outputIdx] = ((T *)input)[inputIdx]; } } +#define CASE(T) \ + _transpose_kernel::t><<>>( \ + input, output, nDims, size, strides, outputShape); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + case 13: \ + CASE(13) \ + break; \ + case 16: \ + CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } namespace infini { -void transpose_kernel(float *input, float *output, int nDims, int size, +void transpose_kernel(int dType, void *input, void *output, int nDims, int size, SmallArray strides, SmallArray outputShape) { int blocksize = block_work_size(); int gridsize = (size + block_work_size() - 1) / block_work_size(); - _transpose_kernel<<>>(input, output, nDims, size, - strides, outputShape); + SWITCH_DTYPE(dType) } } // namespace infini diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index a27d4ac4..bb9691a7 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -2,6 +2,7 @@ #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" #include "cuda/cuda_unary.h" +#include "cuda/cuda_utility.h" namespace infini { @@ -12,6 +13,46 @@ class UnaryCuda : public CudaKernelWithoutConfig { } }; +class CastCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + size_t num = op->getOutput()->size(); + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + if (op->getType() == CastType::Float162Float) { + IT_ASSERT(op->getDType() == DataType::Float16 && + op->getOutDType() == DataType::Float32); + cast_kernel((half *)inputData, (float *)outputData, + num); + } else if (op->getType() == CastType::Float2Float16) { + IT_ASSERT(op->getDType() == DataType::Float32 && + op->getOutDType() == DataType::Float16); + cast_kernel((float *)inputData, (half *)outputData, + num); + } else if (op->getType() == CastType::Float2Int32) { + IT_ASSERT(op->getDType() == DataType::Float32 && + op->getOutDType() == DataType::Int32); + cast_kernel((float *)inputData, + (int32_t *)outputData, num); + } else if (op->getType() == CastType::Float2Int8) { + 
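// ---------------------------------------------------------------------------------
// A minimal sketch of the SWITCH_DTYPE/CASE mechanism used by transpose_kernel above:
// the kernel takes untyped void* buffers, a runtime dtype index
// (op->getDType().getIndex()) selects a compile-time element type, and a switch
// expands into one kernel instantiation per index. The indices appear to follow the
// ONNX TensorProto numbering (1 = float, 10 = float16); the DTypeMapSketch trait and
// the copy kernel below are invented for illustration, only the dispatch idea comes
// from the patch, and only two indices are shown.
#include <cuda_fp16.h>

template <int I> struct DTypeMapSketch;                   // hypothetical index->type trait
template <> struct DTypeMapSketch<1>  { using t = float; };
template <> struct DTypeMapSketch<10> { using t = half;  };

template <typename T>
__global__ void copySketchKernel(void *in, void *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        ((T *)out)[i] = ((T *)in)[i];                     // reinterpret buffers as T
}

inline void copySketch(int dType, void *in, void *out, int n) {
    int grid = (n + 127) / 128;
    switch (dType) {
    case 1:  copySketchKernel<DTypeMapSketch<1>::t><<<grid, 128>>>(in, out, n); break;
    case 10: copySketchKernel<DTypeMapSketch<10>::t><<<grid, 128>>>(in, out, n); break;
    default: break;                                       // the real code calls IT_TODO_HALT()
    }
}
// ---------------------------------------------------------------------------------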
IT_ASSERT(op->getDType() == DataType::Float32 && + op->getOutDType() == DataType::Int8); + cast_kernel((float *)inputData, (int8_t *)outputData, + num); + } else if (op->getType() == CastType::Int82Float) { + IT_ASSERT(op->getDType() == DataType::Int8 && + op->getOutDType() == DataType::Float32); + cast_kernel((int8_t *)inputData, (float *)outputData, + num); + } else { + IT_ASSERT(false); + } + } +}; + class ActivationCudnn : public CudaKernelWithoutConfig { virtual cudnnActivationMode_t getOpType() const = 0; virtual tuple getAlphBeta() const { return {1.f, 0.f}; } @@ -33,17 +74,17 @@ class ActivationCudnn : public CudaKernelWithoutConfig { while (stride.size() < 4) stride.push_back(1); + auto cudnnDataType = cudnnDataTypeConvert(op->getDType()); + // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); - checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT, - dim.size(), dim.data(), - stride.data())); + checkCudnnError(cudnnSetTensorNdDescriptor( + inputDesc, cudnnDataType, dim.size(), dim.data(), stride.data())); // get outputs checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc)); - checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT, - dim.size(), dim.data(), - stride.data())); + checkCudnnError(cudnnSetTensorNdDescriptor( + outputDesc, cudnnDataType, dim.size(), dim.data(), stride.data())); // get op descriptor cudnnActivationDescriptor_t activationDesc; @@ -86,16 +127,18 @@ class SoftmaxCudnn : public CudaKernelWithoutConfig { memcpy(dim_array + (4 - dim.size()), dim.data(), dim.size() * sizeof(int)); + auto cudnnDataType = cudnnDataTypeConvert(op->getDType()); + // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); checkCudnnError(cudnnSetTensor4dDescriptor( - inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0], + inputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0], dim_array[1], dim_array[2], dim_array[3])); // get outputs checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc)); checkCudnnError(cudnnSetTensor4dDescriptor( - outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0], + outputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0], dim_array[1], dim_array[2], dim_array[3])); auto [alpha, beta] = getAlphBeta(); @@ -130,35 +173,27 @@ class TanhCudnn : public ActivationCudnn { } }; -REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn, - "Relu_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn, - "Sigmoid_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, DataType::Float32, UnaryCuda, - "Hard_Sigmoid_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, DataType::Float32, UnaryCuda, - "Hard_Swish_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn, - "Tanh_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda, - "Abs_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, - "Sqrt_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Gelu, DataType::Float32, UnaryCuda, - "Gelu_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Neg, DataType::Float32, UnaryCuda, - "Neg_CUDA_Float32"); -REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda, - "Erf_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCudnn, "Relu_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, SigmoidCudnn, "Sigmoid_CUDA"); +REGISTER_KERNEL(Device::CUDA, 
OpType::HardSigmoid, UnaryCuda, + "Hard_Sigmoid_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, UnaryCuda, "Hard_Swish_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Tanh, TanhCudnn, "Tanh_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, "Abs_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, UnaryCuda, "Sqrt_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA"); -// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda, -// "Softmax_CUDA_Float32"); -// REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, UnaryCuda, -// "Relu_CUDA_Float32"); -// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, UnaryCuda, -// "Sigmoid_CUDA_Float32"); -// REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, UnaryCuda, -// "Tanh_CUDA_Float32"); -// REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda, -// "Abs_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Cast, CastCuda, "Cast_CUDA"); + +// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA"); +// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda, +// "Relu_CUDA"); +// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, UnaryCuda, +// "Sigmoid_CUDA"); +// REGISTER_KERNEL(Device::CUDA, OpType::Tanh, UnaryCuda, +// "Tanh_CUDA"); +// REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, +// "Abs_CUDA"); }; // namespace infini diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 22e2e423..afd7f02a 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -1,6 +1,8 @@ #include "core/common.h" #include "core/constants.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_unary.h" +#include #include using infini::E_CONSTANT; @@ -8,15 +10,16 @@ constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _softmax_kernel1(float *input, float *output, size_t n) { +template +__global__ void _softmax_kernel1(T *input, T *output, size_t n) { float sum = 0.0f; for (size_t i = 0; i < n; ++i) { sum += pow(E_CONSTANT, input[i]); } *output = sum; } - -__global__ void _softmax_kernel2(float *input, float *output, size_t n) { +template +__global__ void _softmax_kernel2(T *input, T *output, size_t n) { float sum = *output; size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; @@ -24,32 +27,32 @@ __global__ void _softmax_kernel2(float *input, float *output, size_t n) { output[i] = pow(E_CONSTANT, input[i]) / sum; } } - -__global__ void _relu_kernel(float *input, float *output, size_t n) { +template +__global__ void _relu_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { output[i] = max(input[i], float(0)); } } - -__global__ void _sigmoid_kernel(float *input, float *output, size_t n) { +template +__global__ void _sigmoid_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { output[i] = 1 / (1 + pow(E_CONSTANT, -input[i])); } } - -__global__ void _hard_sigmoid_kernel(float *input, float *output, size_t n) { +template +__global__ void 
_hard_sigmoid_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { output[i] = max(0.0f, min(1.0f, 0.2f * input[i] + 0.5f)); } } - -__global__ void _hard_swish_kernel(float *input, float *output, size_t n) { +template +__global__ void _hard_swish_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { @@ -57,8 +60,8 @@ __global__ void _hard_swish_kernel(float *input, float *output, size_t n) { input[i] * max(0.f, min(1.f, (1.f / 6.f) * input[i] + 0.5f)); } } - -__global__ void _tanh_kernel(float *input, float *output, size_t n) { +template +__global__ void _tanh_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { @@ -66,8 +69,8 @@ __global__ void _tanh_kernel(float *input, float *output, size_t n) { (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i])); } } - -__global__ void _abs_kernel(float *input, float *output, size_t n) { +template +__global__ void _abs_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (size_t i = index; i < n; i += stride) { @@ -83,7 +86,16 @@ __global__ void _sqrt_kernel(float *input, float *output, size_t n) { } } -__global__ void _gelu_kernel(float *input, float *output, size_t n) { +__global__ void _sqrt_kernel(half *input, half *output, size_t n) { + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + size_t stride = blockDim.x * gridDim.x; + for (size_t i = index; i < n; i += stride) { + output[i] = hsqrt(input[i]); + } +} + +template +__global__ void _gelu_kernel(T *input, T *output, size_t n) { int index = threadIdx.x + blockIdx.x * blockDim.x; int stride = blockDim.x * gridDim.x; for (int i = index; i < n; i += stride) { @@ -91,8 +103,8 @@ __global__ void _gelu_kernel(float *input, float *output, size_t n) { output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f))); } } - -__global__ void _erf_kernel(float *input, float *output, size_t n) { +template +__global__ void _erf_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; for (int i = index; i < n; i += stride) { @@ -109,72 +121,187 @@ __global__ void _neg_kernel(T *input, T *output, size_t n) { } } +template +__global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) { + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + if (index < n) { + cub::CastOp _CastOp; + output[index] = _CastOp(input[index]); + } +} + namespace infini { -void softmax_kernel(float *input, float *output, size_t num) { +template void softmax_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _softmax_kernel1<<<1, 1>>>(input, output, num); - _softmax_kernel2<<>>(input, output, num); + _softmax_kernel1<<<1, 1>>>(input, output, num); + _softmax_kernel2<<>>(input, output, num); } -void relu_kernel(float *input, float *output, size_t num) { +template void relu_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _relu_kernel<<>>(input, output, num); + _relu_kernel<<>>(input, output, num); } -void 
sigmoid_kernel(float *input, float *output, size_t num) { +template void sigmoid_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _sigmoid_kernel<<>>(input, output, num); + _sigmoid_kernel<<>>(input, output, num); } -void hard_sigmoid_kernel(float *input, float *output, size_t num) { +template +void hard_sigmoid_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _hard_sigmoid_kernel<<>>(input, output, num); + _hard_sigmoid_kernel<<>>(input, output, num); } -void hard_swish_kernel(float *input, float *output, size_t num) { +template void hard_swish_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _hard_swish_kernel<<>>(input, output, num); + _hard_swish_kernel<<>>(input, output, num); } -void tanh_kernel(float *input, float *output, size_t num) { +template void tanh_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _tanh_kernel<<>>(input, output, num); + _tanh_kernel<<>>(input, output, num); } -void abs_kernel(float *input, float *output, size_t num) { +template void abs_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _abs_kernel<<>>(input, output, num); + _abs_kernel<<>>(input, output, num); } -void sqrt_kernel(float *input, float *output, size_t num) { +template void sqrt_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _sqrt_kernel<<>>(input, output, num); + _sqrt_kernel<<>>((T *)input, (T *)output, num); } -void gelu_kernel(float *input, float *output, size_t num) { + +template void gelu_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _gelu_kernel<<>>(input, output, num); + _gelu_kernel<<>>(input, output, num); } -void erf_kernel(float *input, float *output, size_t num) { +template void erf_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _erf_kernel<<>>(input, output, num); + _erf_kernel<<>>(input, output, num); } -void neg_kernel(float *input, float *output, size_t num) { +template void neg_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _neg_kernel<<>>(input, output, num); + _neg_kernel<<>>(input, output, num); } + +void unary_kernel(const Operator &_op) { + auto op = as(_op); + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + size_t num = op->getOutput()->size(); + if (op->getOpType() == OpType::Softmax) { + if (_op->getDType() == DataType::Float32) { + softmax_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Relu) { + if (_op->getDType() == DataType::Float32) { + relu_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Sigmoid) { + if (_op->getDType() == DataType::Float32) { + sigmoid_kernel((float 
*)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::HardSigmoid) { + if (_op->getDType() == DataType::Float32) { + hard_sigmoid_kernel((float *)inputData, (float *)outputData, + num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::HardSwish) { + if (_op->getDType() == DataType::Float32) { + hard_swish_kernel((float *)inputData, (float *)outputData, + num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Tanh) { + if (_op->getDType() == DataType::Float32) { + tanh_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Abs) { + if (_op->getDType() == DataType::Float32) { + abs_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Sqrt) { + if (_op->getDType() == DataType::Float32) { + sqrt_kernel((float *)inputData, (float *)outputData, num); + } else if (_op->getDType() == DataType::Float16) { + sqrt_kernel((half *)inputData, (half *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Gelu) { + if (_op->getDType() == DataType::Float32) { + gelu_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else if (op->getOpType() == OpType::Neg) { + if (_op->getDType() == DataType::Float32) { + neg_kernel((float *)inputData, (float *)outputData, num); + } else if (_op->getDType() == DataType::Float16) { + neg_kernel((half *)inputData, (half *)outputData, num); + } else { + IT_TODO_HALT(); + } + } + + else if (op->getOpType() == OpType::Erf) { + if (_op->getDType() == DataType::Float32) { + erf_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } + } else + IT_TODO_HALT(); +} + +template +void cast_kernel(INPUT *input, OUTPUT *output, size_t num) { + + int blocksize = block_work_size(); + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _cast_kernel<<>>(input, output, num); +} + +template void cast_kernel(float *input, half *output, size_t num); +template void cast_kernel(half *input, float *output, size_t num); +template void cast_kernel(float *input, int32_t *output, + size_t num); +template void cast_kernel(float *input, int8_t *output, + size_t num); +template void cast_kernel(int8_t *input, float *output, + size_t num); + }; // namespace infini diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc index df5e4476..da6ac784 100644 --- a/src/kernels/cuda/where.cc +++ b/src/kernels/cuda/where.cc @@ -36,14 +36,22 @@ class WhereCuda : public CudaKernelWithoutConfig { broadcastShape(opInputYShape, inputYShape, nDims, ySize); broadcastShape(opConditionShape, conditionShape, nDims, cSize); - whereKernel((float *)inputXData, (float *)inputYData, - (uint8_t *)conditionData, (float *)outputData, nDims, - outputsize, inputXShape, inputYShape, conditionShape, - outputShape, xSize, ySize, cSize); + if (op->getDType() == DataType::Float32) { + whereKernel((float *)inputXData, (float *)inputYData, + (uint8_t *)conditionData, (float *)outputData, nDims, + outputsize, inputXShape, inputYShape, conditionShape, + outputShape, xSize, ySize, cSize); + } else if (op->getDType() == DataType::Float16) { + whereKernel((half *)inputXData, (half *)inputYData, + (uint8_t *)conditionData, (half *)outputData, nDims, + outputsize, inputXShape, inputYShape, conditionShape, + outputShape, xSize, ySize, cSize); + } else { + 
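// ---------------------------------------------------------------------------------
// A stand-alone sketch of the cast_kernel added in unary.cu above: a one-element-per-
// thread conversion kernel, explicitly instantiated for the (input, output) pairs that
// CastCuda dispatches to. The patch converts through cub::CastOp; plain static_cast,
// used here, behaves the same for these arithmetic types. All names are hypothetical.
#include <cuda_fp16.h>
#include <cstdint>

template <typename IN, typename OUT>
__global__ void castSketchKernel(const IN *in, OUT *out, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = static_cast<OUT>(in[i]);   // per-element conversion
}

template <typename IN, typename OUT>
void castSketch(const IN *in, OUT *out, size_t n) {
    int block = 128;
    int grid = (int)((n + block - 1) / block);
    castSketchKernel<IN, OUT><<<grid, block>>>(in, out, n);
}

// Explicit instantiations keep the definitions in one .cu file while callers only see
// declarations, mirroring the instantiation list at the end of unary.cu.
template void castSketch<float, half>(const float *, half *, size_t);
template void castSketch<half, float>(const half *, float *, size_t);
template void castSketch<float, int32_t>(const float *, int32_t *, size_t);
// ---------------------------------------------------------------------------------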
IT_ASSERT(false); + } } }; -REGISTER_KERNEL(Device::CUDA, OpType::Where, DataType::Float32, WhereCuda, - "Where_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Where, WhereCuda, "Where_CUDA"); }; // namespace infini diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu index ac8b514a..e92a5e9f 100644 --- a/src/kernels/cuda/where.cu +++ b/src/kernels/cuda/where.cu @@ -17,13 +17,13 @@ __device__ int inferIndex(infini::SmallArray inputShape, } return inputIdx; } -__global__ void _whereKernel(const float *inputX, const float *inputY, - const uint8_t *condition, float *output, int nDims, - int outputsize, infini::SmallArray inputXShape, - infini::SmallArray inputYShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape, int xSize, - int ySize, int cSize) { +template +__global__ void +_whereKernel(const T *inputX, const T *inputY, const uint8_t *condition, + T *output, int nDims, int outputsize, + infini::SmallArray inputXShape, infini::SmallArray inputYShape, + infini::SmallArray conditionShape, infini::SmallArray outputShape, + int xSize, int ySize, int cSize) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; if (outputIdx < outputsize) { @@ -61,7 +61,31 @@ void whereKernel(const float *inputX, const float *inputY, blocksize = 32; } int gridsize = (outputsize + blocksize - 1) / blocksize; - _whereKernel<<>>( + _whereKernel<<>>( + inputX, inputY, condition, output, nDims, outputsize, inputXShape, + inputYShape, conditionShape, outputShape, xSize, ySize, cSize); +} +void whereKernel(const half *inputX, const half *inputY, + const uint8_t *condition, half *output, int nDims, + int outputsize, SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape, int xSize, + int ySize, int cSize) { + int blocksize; + if (outputsize > 511) { + blocksize = 1024; + } else if (outputsize > 255) { + blocksize = 512; + } else if (outputsize > 127) { + blocksize = 256; + } else if (outputsize > 63) { + blocksize = 128; + } else if (outputsize > 31) { + blocksize = 64; + } else { + blocksize = 32; + } + int gridsize = (outputsize + blocksize - 1) / blocksize; + _whereKernel<<>>( inputX, inputY, condition, output, nDims, outputsize, inputXShape, inputYShape, conditionShape, outputShape, xSize, ySize, cSize); } diff --git a/src/kernels/intelcpu/batch_norm.cc b/src/kernels/intelcpu/batch_norm.cc index 4583c013..ef024df1 100644 --- a/src/kernels/intelcpu/batch_norm.cc +++ b/src/kernels/intelcpu/batch_norm.cc @@ -7,6 +7,7 @@ class MklBatchNorm : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); float *const srcData = op->getInputs(0)->getRawDataPtr(); @@ -63,6 +64,6 @@ class MklBatchNorm : public MklKernelWithoutConfig { {DNNL_ARG_SHIFT, baisMemory}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNormalization, DataType::Float32, - MklBatchNorm, "BatchNorm_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNormalization, MklBatchNorm, + "BatchNorm_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/concat.cc b/src/kernels/intelcpu/concat.cc index b4e7b24b..85069e7e 100644 --- a/src/kernels/intelcpu/concat.cc +++ b/src/kernels/intelcpu/concat.cc @@ -7,6 +7,7 @@ class MklConcat : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == 
DataType::Float32); auto context = dynamic_cast(_context); // create user memory that describes data layout in the buffers @@ -53,6 +54,5 @@ class MklConcat : public MklKernelWithoutConfig { dnnl::concat(primDesc).execute(context->getStream(), args); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Concat, DataType::Float32, MklConcat, - "Concat_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Concat, MklConcat, "Concat_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/conv.cc b/src/kernels/intelcpu/conv.cc index 77749e09..bd990e08 100644 --- a/src/kernels/intelcpu/conv.cc +++ b/src/kernels/intelcpu/conv.cc @@ -184,6 +184,7 @@ class MklConv : public Kernel { void compute(const Operator &op, const RuntimeObj *context) const override { auto record = make_ref(); + IT_ASSERT(op->getDType() == DataType::Float32); compute(op, record, context); } @@ -233,6 +234,5 @@ class MklConv : public Kernel { return make_ref(ret); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Conv, DataType::Float32, MklConv, - "MklConv_CPU_float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Conv, MklConv, "MklConv_CPU"); } // namespace infini diff --git a/src/kernels/intelcpu/conv_transposed.cc b/src/kernels/intelcpu/conv_transposed.cc index ebf1ad24..b05cc491 100644 --- a/src/kernels/intelcpu/conv_transposed.cc +++ b/src/kernels/intelcpu/conv_transposed.cc @@ -197,6 +197,7 @@ class MklConvTranspose : public Kernel { void compute(const Operator &op, const RuntimeObj *context) const override { auto record = make_ref(); + IT_ASSERT(op->getDType() == DataType::Float32); compute(op, record, context); } @@ -244,7 +245,7 @@ class MklConvTranspose : public Kernel { return make_ref(ret); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTranspose, DataType::Float32, - MklConvTranspose, "MklConvTrans_CPU_float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTranspose, MklConvTranspose, + "MklConvTrans_CPU"); } // namespace infini diff --git a/src/kernels/intelcpu/element_wise.cc b/src/kernels/intelcpu/element_wise.cc index 0a27c31e..3cc5c47b 100644 --- a/src/kernels/intelcpu/element_wise.cc +++ b/src/kernels/intelcpu/element_wise.cc @@ -26,6 +26,7 @@ class MklBinary : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -82,6 +83,7 @@ class MklUnary : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const srcData = (op->getInputs(0)->getRawDataPtr()); @@ -113,21 +115,13 @@ class MklUnary : public MklKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Add, DataType::Float32, MklBinary, - "Add_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Sub, DataType::Float32, MklBinary, - "Sub_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Mul, DataType::Float32, MklBinary, - "Mul_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Div, DataType::Float32, MklBinary, - "Div_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Add, MklBinary, "Add_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Sub, MklBinary, "Sub_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Mul, MklBinary, "Mul_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Div, MklBinary, "Div_Mkl"); 
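// ---------------------------------------------------------------------------------
// The pattern applied to every MKL/oneDNN kernel in these files, condensed into one
// hypothetical example (assuming the project's core headers): since REGISTER_KERNEL no
// longer carries a DataType, a backend that only implements fp32 guards compute() with
// an explicit assert instead of relying on the registry to filter by dtype. The class,
// op choice and registration below are invented for illustration and are not meant to
// be added to the build.
class MklFp32OnlySketch : public MklKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        // Refuse anything the DNNL code (elided here) was not written for.
        IT_ASSERT(_op->getDType() == DataType::Float32);
        // ... create dnnl memory descriptors and execute, as the real kernels do
    }
};
// Registration is keyed by (device, op type) only, and the human-readable name drops
// the "_Float32" suffix.
REGISTER_KERNEL(Device::INTELCPU, OpType::Identity, MklFp32OnlySketch,
                "Identity_Mkl_sketch");
// ---------------------------------------------------------------------------------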
-REGISTER_KERNEL(Device::INTELCPU, OpType::Relu, DataType::Float32, MklUnary, - "Relu_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Sigmoid, DataType::Float32, MklUnary, - "Sigmoid_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Tanh, DataType::Float32, MklUnary, - "Tanh_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Abs, DataType::Float32, MklUnary, - "Abs_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Relu, MklUnary, "Relu_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Sigmoid, MklUnary, "Sigmoid_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Tanh, MklUnary, "Tanh_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Abs, MklUnary, "Abs_Mkl"); } // namespace infini diff --git a/src/kernels/intelcpu/extend.cc b/src/kernels/intelcpu/extend.cc index dff2ebc1..7e2ce225 100644 --- a/src/kernels/intelcpu/extend.cc +++ b/src/kernels/intelcpu/extend.cc @@ -10,6 +10,7 @@ class MklExtend : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto inData = op->getInputs(0)->getRawDataPtr(); auto outData = op->getOutput(0)->getRawDataPtr(); int iSize = op->getInputs(0)->size(); @@ -40,6 +41,5 @@ class MklExtend : public MklKernelWithoutConfig { sycl::free(outDevice, q); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Extend, DataType::Float32, MklExtend, - "Extend_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Extend, MklExtend, "Extend_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/gather.cc b/src/kernels/intelcpu/gather.cc index 61549ccb..61a6fb45 100644 --- a/src/kernels/intelcpu/gather.cc +++ b/src/kernels/intelcpu/gather.cc @@ -10,6 +10,7 @@ class MklGather : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto in = op->getInputs(0); auto index = op->getInputs(1); auto out = op->getOutput(); @@ -81,6 +82,5 @@ class MklGather : public MklKernelWithoutConfig { sycl::free(indexDevice, q); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Gather, DataType::Float32, MklGather, - "Gather_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Gather, MklGather, "Gather_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/matmul.cc b/src/kernels/intelcpu/matmul.cc index 61cf5c94..811c57ba 100644 --- a/src/kernels/intelcpu/matmul.cc +++ b/src/kernels/intelcpu/matmul.cc @@ -7,6 +7,7 @@ template class MklMatmul : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet."); const T *A = op->getInputs(0)->getRawDataPtr(); const T *B = op->getInputs(1)->getRawDataPtr(); @@ -31,7 +32,7 @@ template class MklMatmul : public CpuKernelWithoutConfig { } }; -/*REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, DataType::Float32, +/*REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, MklMatmul, "MklMatmul_CPU_float32");*/ } // namespace infini diff --git a/src/kernels/intelcpu/matmul_dpcpp.cc b/src/kernels/intelcpu/matmul_dpcpp.cc index 8fdddfe2..92db14df 100644 --- a/src/kernels/intelcpu/matmul_dpcpp.cc +++ b/src/kernels/intelcpu/matmul_dpcpp.cc @@ -10,6 +10,7 @@ template class MklDpcppMatmul : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const 
override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet."); const T *A = op->getInputs(0)->getRawDataPtr(); const T *B = op->getInputs(1)->getRawDataPtr(); @@ -69,7 +70,7 @@ template class MklDpcppMatmul : public CpuKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::MatMul, DataType::Float32, - MklDpcppMatmul, "MklDpcppMatmul_CPU_float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::MatMul, MklDpcppMatmul, + "MklDpcppMatmul_CPU"); } // namespace infini diff --git a/src/kernels/intelcpu/pad.cc b/src/kernels/intelcpu/pad.cc index 8f52e7f6..047d5c80 100644 --- a/src/kernels/intelcpu/pad.cc +++ b/src/kernels/intelcpu/pad.cc @@ -7,6 +7,7 @@ class MklPad : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); std::vector dims; @@ -53,6 +54,5 @@ class MklPad : public MklKernelWithoutConfig { {{DNNL_ARG_FROM, srcMemory}, {DNNL_ARG_TO, mem}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Pad, DataType::Float32, MklPad, - "Pad_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Pad, MklPad, "Pad_Mkl"); } // namespace infini diff --git a/src/kernels/intelcpu/pooling.cc b/src/kernels/intelcpu/pooling.cc index d3c9e44d..23e9adee 100644 --- a/src/kernels/intelcpu/pooling.cc +++ b/src/kernels/intelcpu/pooling.cc @@ -9,6 +9,7 @@ class MklPooling : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); float *const srcData = op->getInputs(0)->getRawDataPtr(); @@ -77,8 +78,7 @@ class MklMaxPool : public MklPooling { } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::AveragePool, DataType::Float32, - MklAvgPool, "AvgPool_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::MaxPool, DataType::Float32, - MklMaxPool, "MaxPool_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::AveragePool, MklAvgPool, + "AvgPool_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::MaxPool, MklMaxPool, "MaxPool_Mkl"); } // namespace infini diff --git a/src/kernels/intelcpu/pow.cc b/src/kernels/intelcpu/pow.cc index 166d0a75..493c3148 100644 --- a/src/kernels/intelcpu/pow.cc +++ b/src/kernels/intelcpu/pow.cc @@ -11,6 +11,7 @@ class MklPow : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto in0Data = op->getInputs(0)->getRawDataPtr(); auto in1Data = op->getInputs(1)->getRawDataPtr(); auto outData = op->getOutput(0)->getRawDataPtr(); @@ -37,7 +38,6 @@ class MklPow : public MklKernelWithoutConfig { sycl::free(outDevice, q); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Pow, DataType::Float32, MklPow, - "Pow_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Pow, MklPow, "Pow_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/reduce.cc b/src/kernels/intelcpu/reduce.cc index 6670229e..a63ec014 100644 --- a/src/kernels/intelcpu/reduce.cc +++ b/src/kernels/intelcpu/reduce.cc @@ -1,6 +1,6 @@ +#include "operators/reduce.h" #include "intelcpu/mkl_kernel_without_config.h" #include "intelcpu/mkl_runtime.h" -#include "operators/reduce_mean.h" namespace infini { class MklReduce : public MklKernelWithoutConfig { @@ -11,6 +11,7 @@ class MklReduce : 
public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); float *const srcData = op->getInputs(0)->getRawDataPtr(); @@ -64,6 +65,6 @@ class MklReduce : public MklKernelWithoutConfig { {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::ReduceMean, DataType::Float32, - MklReduce, "ReduceMean_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::ReduceMean, MklReduce, + "ReduceMean_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/reshape.cc b/src/kernels/intelcpu/reshape.cc index 2a17b881..a1432ac9 100644 --- a/src/kernels/intelcpu/reshape.cc +++ b/src/kernels/intelcpu/reshape.cc @@ -6,7 +6,7 @@ namespace infini { class MklReshape : public MklKernelWithoutConfig { void compute(const Operator &op, const RuntimeObj *_context) const override { - + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); std::vector dims; @@ -41,10 +41,7 @@ class MklReshape : public MklKernelWithoutConfig { {{DNNL_ARG_FROM, reshapeMemory}, {DNNL_ARG_TO, output}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Reshape, DataType::Float32, - MklReshape, "Reshape_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Identity, DataType::Float32, - MklReshape, "Identify_Mkl_Float32"); -REGISTER_KERNEL(Device::INTELCPU, OpType::Flatten, DataType::Float32, - MklReshape, "Flatten_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Reshape, MklReshape, "Reshape_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Identity, MklReshape, "Identify_Mkl"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Flatten, MklReshape, "Flatten_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/resize.cc b/src/kernels/intelcpu/resize.cc index f9a85634..524db879 100644 --- a/src/kernels/intelcpu/resize.cc +++ b/src/kernels/intelcpu/resize.cc @@ -24,6 +24,7 @@ class MklResize : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); // only support default coordinate transmode?? 
if (op->getCoordinateTransMode() != @@ -75,6 +76,5 @@ class MklResize : public MklKernelWithoutConfig { {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Resize, DataType::Float32, MklResize, - "Resize_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Resize, MklResize, "Resize_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/slice.cc b/src/kernels/intelcpu/slice.cc index a5715ced..42e45d5b 100644 --- a/src/kernels/intelcpu/slice.cc +++ b/src/kernels/intelcpu/slice.cc @@ -7,6 +7,7 @@ class MklSlice : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); std::vector dims; @@ -41,6 +42,5 @@ class MklSlice : public MklKernelWithoutConfig { {{DNNL_ARG_FROM, sliceMemory}, {DNNL_ARG_TO, output}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Slice, DataType::Float32, MklSlice, - "Slice_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Slice, MklSlice, "Slice_Mkl"); } // namespace infini diff --git a/src/kernels/intelcpu/softmax.cc b/src/kernels/intelcpu/softmax.cc index 32c58a94..fe88cefa 100644 --- a/src/kernels/intelcpu/softmax.cc +++ b/src/kernels/intelcpu/softmax.cc @@ -7,6 +7,7 @@ class MklSoftmax : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); float *const srcData = op->getInputs(0)->getRawDataPtr(); @@ -38,6 +39,5 @@ class MklSoftmax : public MklKernelWithoutConfig { {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Softmax, DataType::Float32, - MklSoftmax, "Softmax_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Softmax, MklSoftmax, "Softmax_Mkl"); }; // namespace infini diff --git a/src/kernels/intelcpu/split.cc b/src/kernels/intelcpu/split.cc index df859083..37d28360 100644 --- a/src/kernels/intelcpu/split.cc +++ b/src/kernels/intelcpu/split.cc @@ -7,6 +7,7 @@ class MklSplit : public MklKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); std::vector dims; @@ -49,6 +50,5 @@ class MklSplit : public MklKernelWithoutConfig { } } }; -REGISTER_KERNEL(Device::INTELCPU, OpType::Split, DataType::Float32, MklSplit, - "Split_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Split, MklSplit, "Split_Mkl"); }; // namespace infini diff --git a/src/kernels/kunlun/batch_norm.cc b/src/kernels/kunlun/batch_norm.cc index d1c8c3b4..d0e1c9b2 100644 --- a/src/kernels/kunlun/batch_norm.cc +++ b/src/kernels/kunlun/batch_norm.cc @@ -7,6 +7,7 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const input = (op->getInputs(0)->getRawDataPtr()); @@ -35,7 +36,7 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::BatchNormalization, DataType::Float32, - BatchNormXdnn, "BatchNorm_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::BatchNormalization, BatchNormXdnn, + "BatchNorm_xdnn_KUNLUN"); }; // namespace infini 
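// ---------------------------------------------------------------------------------
// For contrast with the fp32-only asserts used by the INTELCPU and KUNLUN backends,
// the CUDA kernels in this patch keep a single registration and branch on the runtime
// dtype inside compute(), as ConcatCuda, SplitCuda and WhereCuda do earlier in this
// patch. A condensed sketch of that shape, assuming the project's core headers and
// <cuda_fp16.h>; the helper names are invented and run_typed_sketch stands in for any
// templated CUDA launch.
template <class T>
void run_typed_sketch(void *in, void *out, size_t n);   // hypothetical typed launcher

inline void dispatch_on_dtype_sketch(const Operator &op, void *in, void *out, size_t n) {
    if (op->getDType() == DataType::Float32) {
        run_typed_sketch<float>(in, out, n);
    } else if (op->getDType() == DataType::Float16) {
        run_typed_sketch<half>(in, out, n);
    } else {
        IT_ASSERT(false);   // this kernel has not been wired up for other dtypes
    }
}
// ---------------------------------------------------------------------------------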
diff --git a/src/kernels/kunlun/cast.cc b/src/kernels/kunlun/cast.cc index 443cc259..0bd7e4e8 100644 --- a/src/kernels/kunlun/cast.cc +++ b/src/kernels/kunlun/cast.cc @@ -93,6 +93,5 @@ class CastXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Cast, DataType::Float32, CastXdnn, - "Cast_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Cast, CastXdnn, "Cast_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/concat.cc b/src/kernels/kunlun/concat.cc index 35777cae..f7ba2a2d 100644 --- a/src/kernels/kunlun/concat.cc +++ b/src/kernels/kunlun/concat.cc @@ -7,6 +7,7 @@ class ConcatXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int axis = op->getDim(); int num = op->numInputs(); @@ -32,6 +33,6 @@ class ConcatXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Concat, DataType::Float32, ConcatXdnn, - "Concat_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Concat, ConcatXdnn, + "Concat_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/conv.cc b/src/kernels/kunlun/conv.cc index 80cc37c7..45f054b1 100644 --- a/src/kernels/kunlun/conv.cc +++ b/src/kernels/kunlun/conv.cc @@ -7,6 +7,7 @@ class ConvXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -32,6 +33,5 @@ class ConvXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Conv, DataType::Float32, ConvXdnn, - "Conv_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Conv, ConvXdnn, "Conv_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/conv_trans.cc b/src/kernels/kunlun/conv_trans.cc index 841955a6..8219d829 100644 --- a/src/kernels/kunlun/conv_trans.cc +++ b/src/kernels/kunlun/conv_trans.cc @@ -7,6 +7,7 @@ class ConvTransXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -46,9 +47,9 @@ class ConvTransXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTranspose, DataType::Float32, - ConvTransXdnn, "ConvTrans_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTransNHWC, DataType::Float32, - ConvTransXdnn, "ConvTranposedNHWC_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTranspose, ConvTransXdnn, + "ConvTrans_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTransNHWC, ConvTransXdnn, + "ConvTranposedNHWC_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc index 3370eb1a..5a9754f5 100644 --- a/src/kernels/kunlun/element_wise.cc +++ b/src/kernels/kunlun/element_wise.cc @@ -7,6 +7,7 @@ class AddXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const 
aData = (op->getInputs(0)->getRawDataPtr()); @@ -33,6 +34,7 @@ class SubXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -59,6 +61,7 @@ class MulXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -85,6 +88,7 @@ class DivXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -111,6 +115,7 @@ class PowXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -138,6 +143,7 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -164,6 +170,7 @@ class MinXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -190,6 +197,7 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -220,6 +228,7 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -250,6 +259,7 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -280,6 +290,7 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -310,6 +321,7 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -340,6 +352,7 @@ 
class FloorDivXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -366,6 +379,7 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -386,6 +400,7 @@ class AndXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -416,6 +431,7 @@ class OrXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -446,6 +462,7 @@ class XorXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -476,6 +493,7 @@ class NotXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -493,40 +511,28 @@ class NotXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Add, DataType::Float32, AddXdnn, - "Add_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Sub, DataType::Float32, SubXdnn, - "Sub_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Mul, DataType::Float32, MulXdnn, - "Mul_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Div, DataType::Float32, DivXdnn, - "Div_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Pow, DataType::Float32, PowXdnn, - "Pow_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Max, DataType::Float32, MaxXdnn, - "Max_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Min, DataType::Float32, MinXdnn, - "Min_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Equal, DataType::Float32, EqualXdnn, - "Equal_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::GreaterOrEqual, DataType::Float32, - GreaterEqualXdnn, "GreaterEqual_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Greater, DataType::Float32, - GreaterThanXdnn, "GreaterThan_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::LessOrEqual, DataType::Float32, - LessEqualXdnn, "LessEqual_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Less, DataType::Float32, LessThanXdnn, - "LessThan_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::FloorDiv, DataType::Float32, - FloorDivXdnn, "FloorDiv_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::MSELoss, DataType::Float32, MSELossXdnn, - "MSELoss_xdnn_KUNLUN_Float32"); 
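// ---------------------------------------------------------------------------------
// Why dropping DataType from REGISTER_KERNEL is more than renaming: the macro
// arguments form the key under which a kernel is filed, so lookup is now by
// (device, op type) alone and each kernel checks or dispatches on the dtype itself.
// The real registry and macro live in include/core/kernel.h (also changed by this
// patch); everything below -- SketchKernel, SketchRegistry, REGISTER_KERNEL_SKETCH --
// is invented to show the usual shape of such a self-registration mechanism, not the
// project's actual implementation.
#include <map>
#include <string>
#include <utility>

struct SketchKernel { virtual ~SketchKernel() = default; };

class SketchRegistry {
  public:
    static SketchRegistry &instance() {
        static SketchRegistry r;
        return r;
    }
    // The key deliberately has no dtype component any more.
    bool add(std::pair<int, int> deviceAndOp, SketchKernel *k, std::string name) {
        kernels[deviceAndOp] = {k, std::move(name)};
        return true;
    }

  private:
    std::map<std::pair<int, int>, std::pair<SketchKernel *, std::string>> kernels;
};

// The static bool forces registration to run when the object file is loaded; a real
// macro would also make the variable name unique per invocation.
#define REGISTER_KERNEL_SKETCH(device, opType, Class, name)                           \
    static const bool kSketchReg_##Class = SketchRegistry::instance().add(            \
        {static_cast<int>(device), static_cast<int>(opType)}, new Class(), name);
// ---------------------------------------------------------------------------------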
-REGISTER_KERNEL(Device::KUNLUN, OpType::And, DataType::Float32, AndXdnn, - "And_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Or, DataType::Float32, OrXdnn, - "Or_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Xor, DataType::Float32, XorXdnn, - "Xor_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Not, DataType::Float32, NotXdnn, - "Not_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Add, AddXdnn, "Add_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sub, SubXdnn, "Sub_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Mul, MulXdnn, "Mul_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Div, DivXdnn, "Div_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Pow, PowXdnn, "Pow_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Max, MaxXdnn, "Max_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Min, MinXdnn, "Min_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Equal, EqualXdnn, "Equal_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::GreaterOrEqual, GreaterEqualXdnn, + "GreaterEqual_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Greater, GreaterThanXdnn, + "GreaterThan_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::LessOrEqual, LessEqualXdnn, + "LessEqual_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Less, LessThanXdnn, + "LessThan_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::FloorDiv, FloorDivXdnn, + "FloorDiv_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::MSELoss, MSELossXdnn, + "MSELoss_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::And, AndXdnn, "And_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Or, OrXdnn, "Or_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Xor, XorXdnn, "Xor_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Not, NotXdnn, "Not_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/gather.cc b/src/kernels/kunlun/gather.cc index f94d24fa..75fd2365 100644 --- a/src/kernels/kunlun/gather.cc +++ b/src/kernels/kunlun/gather.cc @@ -7,6 +7,7 @@ class GatherXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -24,6 +25,6 @@ class GatherXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Gather, DataType::Float32, GatherXdnn, - "Gather_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Gather, GatherXdnn, + "Gather_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/matmul.cc b/src/kernels/kunlun/matmul.cc index 8506e812..f70394a9 100644 --- a/src/kernels/kunlun/matmul.cc +++ b/src/kernels/kunlun/matmul.cc @@ -7,6 +7,7 @@ class MatmulXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -28,6 +29,6 @@ class MatmulXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::MatMul, DataType::Float32, MatmulXdnn, - "Matmul_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::MatMul, MatmulXdnn, + "Matmul_xdnn_KUNLUN"); }; // namespace infini diff --git 
a/src/kernels/kunlun/pad.cc b/src/kernels/kunlun/pad.cc index 2ae93d99..de063828 100644 --- a/src/kernels/kunlun/pad.cc +++ b/src/kernels/kunlun/pad.cc @@ -7,6 +7,7 @@ class PadXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -31,7 +32,6 @@ class PadXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Pad, DataType::Float32, PadXdnn, - "Pad_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Pad, PadXdnn, "Pad_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/pooling.cc b/src/kernels/kunlun/pooling.cc index 27b8458a..bc49e31c 100644 --- a/src/kernels/kunlun/pooling.cc +++ b/src/kernels/kunlun/pooling.cc @@ -7,6 +7,7 @@ class AvgPooling : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -30,6 +31,7 @@ class MaxPooling : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -55,8 +57,7 @@ class MaxPooling : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::MaxPool, DataType::Float32, MaxPooling, - "MaxPool_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::AveragePool, DataType::Float32, - AvgPooling, "AvgPool_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::MaxPool, MaxPooling, "MaxPool_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::AveragePool, AvgPooling, + "AvgPool_xdnn"); }; // namespace infini diff --git a/src/kernels/kunlun/reduce_mean.cc b/src/kernels/kunlun/reduce_mean.cc index c7cf19ac..928d42c8 100644 --- a/src/kernels/kunlun/reduce_mean.cc +++ b/src/kernels/kunlun/reduce_mean.cc @@ -7,6 +7,7 @@ class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -25,6 +26,6 @@ class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, DataType::Float32, - ReduceMeanXdnn, "ReduceMean_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, ReduceMeanXdnn, + "ReduceMean_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/select.cc b/src/kernels/kunlun/select.cc index d6318e46..7cdfd8bf 100644 --- a/src/kernels/kunlun/select.cc +++ b/src/kernels/kunlun/select.cc @@ -7,6 +7,7 @@ class WhereXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -27,6 +28,5 @@ class WhereXdnn : public KUNLUNKernelWithoutConfig { } }; 
-REGISTER_KERNEL(Device::KUNLUN, OpType::Where, DataType::Float32, WhereXdnn, - "Where_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Where, WhereXdnn, "Where_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/softmax.cc b/src/kernels/kunlun/softmax.cc index 56374766..552b6c21 100644 --- a/src/kernels/kunlun/softmax.cc +++ b/src/kernels/kunlun/softmax.cc @@ -7,6 +7,7 @@ class SoftmaxXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); auto dim = op->getInputs(0)->getDims(); auto axis = op->getAxis(); @@ -21,6 +22,6 @@ class SoftmaxXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Softmax, DataType::Float32, SoftmaxXdnn, - "Softmax_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Softmax, SoftmaxXdnn, + "Softmax_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/split.cc b/src/kernels/kunlun/split.cc index 46276c85..f76f86ff 100644 --- a/src/kernels/kunlun/split.cc +++ b/src/kernels/kunlun/split.cc @@ -7,6 +7,7 @@ class SplitXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int axis = op->getDim(); int num = op->numOutputs(); @@ -33,6 +34,5 @@ class SplitXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Split, DataType::Float32, SplitXdnn, - "Split_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Split, SplitXdnn, "Split_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/transpose.cc b/src/kernels/kunlun/transpose.cc index 817c32e2..7a89480e 100644 --- a/src/kernels/kunlun/transpose.cc +++ b/src/kernels/kunlun/transpose.cc @@ -7,6 +7,7 @@ class TransposeXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -31,6 +32,7 @@ class DepthToSpaceXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -52,8 +54,8 @@ class DepthToSpaceXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Transpose, DataType::Float32, - TransposeXdnn, "Transpose_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::DepthToSpace, DataType::Float32, - DepthToSpaceXdnn, "DepthToSpace_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Transpose, TransposeXdnn, + "Transpose_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::DepthToSpace, DepthToSpaceXdnn, + "DepthToSpace_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/unary.cc b/src/kernels/kunlun/unary.cc index c24fddaf..3b444d3b 100644 --- a/src/kernels/kunlun/unary.cc +++ b/src/kernels/kunlun/unary.cc @@ -7,6 +7,7 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == 
DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -24,6 +25,7 @@ class SigmoidXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -41,6 +43,7 @@ class TanhXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -58,6 +61,7 @@ class SquareXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -75,6 +79,7 @@ class SqrtXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -92,6 +97,7 @@ class RsqrtXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -109,6 +115,7 @@ class ExpXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -126,6 +133,7 @@ class CeilXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -143,6 +151,7 @@ class ClipXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -163,6 +172,7 @@ class FloorXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -180,6 +190,7 @@ class NegXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -196,6 +207,7 @@ class NegXdnn : public KUNLUNKernelWithoutConfig { class CopyXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &op, const RuntimeObj *_context) const override { + IT_ASSERT(op->getDType() == DataType::Float32); auto context = 
dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -213,6 +225,7 @@ class ReciprocalXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -230,6 +243,7 @@ class AbsXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -247,6 +261,7 @@ class ATanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -264,6 +279,7 @@ class LogXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -316,6 +332,7 @@ class CosXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -332,6 +349,7 @@ class SinXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -348,6 +366,7 @@ class TanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -364,6 +383,7 @@ class SinhXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -380,6 +400,7 @@ class CoshXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -396,6 +417,7 @@ class ErfXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = 
(op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -412,6 +434,7 @@ class ACosXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -428,6 +451,7 @@ class ACoshXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -444,6 +468,7 @@ class ASinXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -460,6 +485,7 @@ class ASinhXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -476,6 +502,7 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -488,63 +515,38 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, DataType::Float32, ReluXdnn, - "Relu_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, DataType::Float32, SigmoidXdnn, - "Sigmoid_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, DataType::Float32, TanhXdnn, - "Tanh_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Square, DataType::Float32, SquareXdnn, - "Square_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Sqrt, DataType::Float32, SqrtXdnn, - "Sqrt_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Rsqrt, DataType::Float32, RsqrtXdnn, - "Rsqrt_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Exp, DataType::Float32, ExpXdnn, - "Exp_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Ceil, DataType::Float32, CeilXdnn, - "Ceil_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Clip, DataType::Float32, ClipXdnn, - "Clip_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Floor, DataType::Float32, FloorXdnn, - "Floor_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Neg, DataType::Float32, NegXdnn, - "Neg_xdnn_KUNLUN_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, DataType::Float32, - ReciprocalXdnn, "Reciprocal_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn, + "Sigmoid_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, 
OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Square, SquareXdnn, + "Square_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sqrt, SqrtXdnn, "Sqrt_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Rsqrt, RsqrtXdnn, "Rsqrt_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Exp, ExpXdnn, "Exp_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Ceil, CeilXdnn, "Ceil_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Clip, ClipXdnn, "Clip_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Floor, FloorXdnn, "Floor_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Neg, NegXdnn, "Neg_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, ReciprocalXdnn, + "Reciprocal_xdnn_KUNLUN"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, DataType::Float32, CopyXdnn, - "Reshape_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, DataType::Float32, CopyXdnn, - "Flatten_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, DataType::Float32, CopyXdnn, - "Identity_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, DataType::Float32, AbsXdnn, - "Abs_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, DataType::Float32, ATanXdnn, - "Atan_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Log, DataType::Float32, LogXdnn, - "Log_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Cos, DataType::Float32, CosXdnn, - "Cos_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Sin, DataType::Float32, SinXdnn, - "Sin_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Tan, DataType::Float32, TanXdnn, - "Tan_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Sinh, DataType::Float32, SinhXdnn, - "Sinh_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Cosh, DataType::Float32, CoshXdnn, - "Cosh_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, DataType::Float32, ErfXdnn, - "Erf_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, DataType::Float32, ACosXdnn, - "ACos_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, DataType::Float32, ACoshXdnn, - "ACosh_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, DataType::Float32, ASinXdnn, - "ASin_xdnn_Float32"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, DataType::Float32, ASinhXdnn, +REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, CopyXdnn, "Reshape_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, CopyXdnn, "Flatten_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, CopyXdnn, "Identity_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, AbsXdnn, "Abs_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, ATanXdnn, "Atan_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Log, LogXdnn, "Log_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Cos, CosXdnn, "Cos_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sin, SinXdnn, "Sin_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Tan, TanXdnn, "Tan_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sinh, SinhXdnn, "Sinh_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Cosh, CoshXdnn, "Cosh_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, ErfXdnn, "Erf_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, ACosXdnn, "ACos_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, ACoshXdnn, "ACosh_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, ASinXdnn, "ASin_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, ASinhXdnn, "ASinh_xdnn"); -REGISTER_KERNEL(Device::KUNLUN, 
OpType::Atanh, DataType::Float32, ATanhXdnn, - "ATanh_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, ATanhXdnn, "ATanh_xdnn"); }; // namespace infini diff --git a/src/kunlun/kunlun_runtime.cc b/src/kunlun/kunlun_runtime.cc index b40e772f..b614ac9c 100644 --- a/src/kunlun/kunlun_runtime.cc +++ b/src/kunlun/kunlun_runtime.cc @@ -13,8 +13,7 @@ void KUNLUNRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/operators/layer_norm.cc b/src/operators/layer_norm.cc index 68649215..5109c79b 100644 --- a/src/operators/layer_norm.cc +++ b/src/operators/layer_norm.cc @@ -27,10 +27,7 @@ optional> LayerNormObj::inferShape(const TensorVec &inputs) { vector LayerNormObj::inferDataType(const TensorVec &inputs) const { IT_ASSERT(inputs.size() == 2 || inputs.size() == 3); - IT_ASSERT(inputs[1]->getDType() == DataType::Float32); - if (inputs.size() == 3) { - IT_ASSERT(inputs[2]->getDType() == DataType::Float32); - } + return {inputs[0]->getDType()}; } diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index 6687a8fd..b191fb33 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -112,7 +112,6 @@ std::string device_to_str(Device device) { std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) { std::string deviceStr = device_to_str(std::get<0>(kernelAttrs)); std::string opStr = OpType(std::get<1>(kernelAttrs)).toString(); - std::string datatypeStr = std::get<2>(kernelAttrs).toString(); - return deviceStr + ", " + opStr + ", " + datatypeStr; + return deviceStr + ", " + opStr; } } // namespace infini diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 12e18a56..a17ba24b 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -187,4 +187,42 @@ TEST(ConcatToIdentity, Cuda) { EXPECT_TRUE( oCpu->equalData(vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); } +//---------- +TEST(ConcatFp16, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1, 2}, DataType::Float16); + auto t2 = gCpu->addTensor({2, 2, 1, 1, 2}, DataType::Float16); + auto t3 = gCpu->addTensor({2, 2, 2, 1, 2}, DataType::Float16); + gCpu->dataMalloc(); + t1->setData(ValGenerator<2>()); + t2->setData(ValGenerator<1>()); + t3->setData(ValGenerator<4>()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + auto t2Gpu = gCuda->cloneTensor(t2); + auto t3Gpu = gCuda->cloneTensor(t3); + + auto op = + gCuda->addOp(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 2); + gCuda->dataMalloc(); + t1Gpu->setData(ValGenerator<2>()); + t2Gpu->setData(ValGenerator<1>()); + t3Gpu->setData(ValGenerator<4>()); + + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData(vector{ + 2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2., + 2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2., 2., 2., 1., 1., + 4., 4., 
4., 4., 2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4.})); +} + } // namespace infini diff --git a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc index 0c8899e4..3825a06e 100644 --- a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc +++ b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc @@ -160,8 +160,8 @@ TEST(cuDNN_ConvTransposed, tune) { bool tune = true; cuda->run(gCuda, tune); // check record - auto kernelAttrs = KernelAttrs{Device::CUDA, conv->getOpType().underlying(), - DataType::Float32}; + auto kernelAttrs = + KernelAttrs{Device::CUDA, conv->getOpType().underlying()}; auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; std::optional perfData = PerfEngine::getInstance().getPerfData(perfKey); diff --git a/test/kernels/cuda/test_cuda_layernorm.cc b/test/kernels/cuda/test_cuda_layernorm.cc index 18b8c4df..e2af489e 100644 --- a/test/kernels/cuda/test_cuda_layernorm.cc +++ b/test/kernels/cuda/test_cuda_layernorm.cc @@ -8,7 +8,7 @@ namespace infini { -void test_layernorm( +void test_layernormFp32( const Shape &inputShape, const vector &inputData, const Shape &scaleShape, const vector &scaleData, float eps, int axis, int stash_type, const vector &ExpectData, @@ -77,9 +77,78 @@ void test_layernorm( EXPECT_TRUE(oCpu->equalData(ExpectData)); } } +void test_layernormFp16( + const Shape &inputShape, + const std::function &generator, + const Shape &scaleShape, float eps, int axis, int stash_type, + const vector &ExpectData, + const std::optional &bShape = std::nullopt) { -TEST(CUDA_Layernorm, run) { - test_layernorm( + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + if (bShape.has_value()) { + Shape biasShape = *bShape; + + auto bias = gCpu->addTensor(biasShape, DataType::Float16); + auto input = gCpu->addTensor(inputShape, DataType::Float16); + auto scale = gCpu->addTensor(scaleShape, DataType::Float16); + gCpu->dataMalloc(); + bias->setData(generator); + // bias->printData(); + input->setData(generator); + scale->setData(generator); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + auto biasGpu = gCuda->cloneTensor(bias); + auto inputGpu = gCuda->cloneTensor(input); + auto scaleGpu = gCuda->cloneTensor(scale); + // gCpu->cloneTensor(biasGpu)->printData(); + auto op = + gCuda->addOp(inputGpu, scaleGpu, nullptr, biasGpu, + eps, axis, stash_type); // LayernormObj + gCuda->dataMalloc(); + biasGpu->setData(generator); + // gCpu->cloneTensor(biasGpu)->printData(); + inputGpu->setData(generator); + scaleGpu->setData(generator); + cudaRuntime->run(gCuda); + + auto oCpu = + gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectData)); + } else { + + auto input = gCpu->addTensor(inputShape, DataType::Float16); + auto scale = gCpu->addTensor(scaleShape, DataType::Float16); + gCpu->dataMalloc(); + + input->setData(generator); + scale->setData(generator); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto scaleGpu = gCuda->cloneTensor(scale); + auto op = + gCuda->addOp(inputGpu, scaleGpu, nullptr, nullptr, + eps, axis, stash_type); // LayernormObj + gCuda->dataMalloc(); + + inputGpu->setData(generator); + scaleGpu->setData(generator); + cudaRuntime->run(gCuda); + + auto oCpu = + gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + 
EXPECT_TRUE(oCpu->equalData(ExpectData)); + } +} + +TEST(CUDA_LayernormFp32, run) { + test_layernormFp32( Shape{2, 3, 2, 3}, vector{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., @@ -94,7 +163,7 @@ TEST(CUDA_Layernorm, run) { -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678}, Shape{3}, vector{0, 0, 0}); - test_layernorm( + test_layernormFp32( Shape{2, 3, 2, 3}, vector{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., @@ -109,7 +178,7 @@ TEST(CUDA_Layernorm, run) { -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679}, Shape{3}, vector{0.3, 0.2, 0.5}); - test_layernorm( + test_layernormFp32( Shape{2, 3, 2, 3}, vector{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., @@ -124,7 +193,7 @@ TEST(CUDA_Layernorm, run) { -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207}, Shape{3}, vector{0.3, 0.2, 0.5}); - test_layernorm( + test_layernormFp32( Shape{2, 3, 2, 3}, vector{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., @@ -141,6 +210,15 @@ TEST(CUDA_Layernorm, run) { 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678}); +} // python output +TEST(CUDA_LayernormFp16, run) { + test_layernormFp16(Shape{2, 3, 2, 3}, ValGenerator<2>(), Shape{3}, 1e-5, 3, + 1, vector{2., 2., 2., 2., 2., 2., 2., 2., 2., + 2., 2., 2., 2., 2., 2., 2., 2., 2., + 2., 2., 2., 2., 2., 2., 2., 2., 2., + 2., 2., 2., 2., 2., 2., 2., 2., 2.}, + Shape{3}); + } // python output } // namespace infini diff --git a/test/kernels/cuda/test_cuda_softmax.cc b/test/kernels/cuda/test_cuda_softmax.cc index 9ce9705d..be73554d 100644 --- a/test/kernels/cuda/test_cuda_softmax.cc +++ b/test/kernels/cuda/test_cuda_softmax.cc @@ -8,130 +8,127 @@ #include namespace infini { -TEST(cuDNN_Softmax, run_axis1) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); +void test_softmaxFp32(const Shape &inputShape, const vector &inputData, + int axis, const vector &ExpectData) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor(inputShape, DataType::Float32); + + gCpu->dataMalloc(); + + input->copyin(inputData); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); - // Build input data on CPU - Tensor inputCpu = - make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + auto inputGpu = gCuda->cloneTensor(input); - // GPU - Graph cudaGraph = make_ref(cudaRuntime); - auto inputGpu = cudaGraph->cloneTensor(inputCpu); - auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 1); - cudaGraph->dataMalloc(); - inputGpu->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); - cudaRuntime->run(cudaGraph); - auto outputGpu = gpuOp->getOutput(); - auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - cudaPrintTensor(outputGpu); - // Check - EXPECT_TRUE(outputGpu2Cpu->equalData( - vector{0.032058604, 0.08714432, 0.23688284, 0.6439143, - 0.032058604, 0.08714432, 0.23688284, 0.6439143})); + auto op = gCuda->addOp(inputGpu, nullptr, axis); + gCuda->dataMalloc(); + + inputGpu->copyin(inputData); + + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + 
EXPECT_TRUE(oCpu->equalData(ExpectData)); } +void test_softmaxFp16( + const Shape &inputShape, + const std::function &generator, int axis, + const vector &ExpectData) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor(inputShape, DataType::Float16); + + gCpu->dataMalloc(); + + input->setData(generator); -TEST(cuDNN_Softmax, run_axis0) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); - // Build input data on CPU - Tensor inputCpu = - make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + auto inputGpu = gCuda->cloneTensor(input); - // GPU - Graph cudaGraph = make_ref(cudaRuntime); - auto inputGpu = cudaGraph->cloneTensor(inputCpu); - auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 0); - cudaGraph->dataMalloc(); - inputGpu->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); - cudaRuntime->run(cudaGraph); - auto outputGpu = gpuOp->getOutput(); - auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - cudaPrintTensor(outputGpu); - // Check - EXPECT_TRUE( - outputGpu2Cpu->equalData(vector{0., 0., 0., 0., 1, 1, 1, 1})); + auto op = gCuda->addOp(inputGpu, nullptr, axis); + gCuda->dataMalloc(); + + inputGpu->setData(generator); + + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectData)); } +TEST(CUDA_SoftmaxFP32, run) { + test_softmaxFp32( + Shape{2, 3, 2, 2}, + vector{0., 1., 2., 3., 4., 5., 6., 7., + 8., 9., 10., 11., 12., 13., 14., 15., + 16., 17., 18., 19., 20., 21., 22., 23.}, + 0, vector{6.14417422e-06, 6.14417422e-06, 6.14417422e-06, + 6.14417422e-06, 6.14417422e-06, 6.14417422e-06, + 6.14417422e-06, 6.14417422e-06, 6.14417422e-06, + 6.14417422e-06, 6.14417422e-06, 6.14417422e-06, + 9.99993801e-01, 9.99993801e-01, 9.99993801e-01, + 9.99993801e-01, 9.99993801e-01, 9.99993801e-01, + 9.99993801e-01, 9.99993801e-01, 9.99993801e-01, + 9.99993801e-01, 9.99993801e-01, 9.99993801e-01}); + test_softmaxFp32( + Shape{2, 3, 2, 2}, + vector{0., 1., 2., 3., 4., 5., 6., 7., + 8., 9., 10., 11., 12., 13., 14., 15., + 16., 17., 18., 19., 20., 21., 22., 23.}, + 1, vector{3.29320435e-04, 3.29320435e-04, 3.29320435e-04, + 3.29320435e-04, 1.79802869e-02, 1.79802869e-02, + 1.79802869e-02, 1.79802869e-02, 9.81690347e-01, + 9.81690347e-01, 9.81690347e-01, 9.81690347e-01, + 3.29320435e-04, 3.29320435e-04, 3.29320435e-04, + 3.29320435e-04, 1.79802869e-02, 1.79802869e-02, + 1.79802869e-02, 1.79802869e-02, 9.81690347e-01, + 9.81690347e-01, 9.81690347e-01, 9.81690347e-01}); + test_softmaxFp32( + Shape{2, 3, 2, 2}, + vector{0., 1., 2., 3., 4., 5., 6., 7., + 8., 9., 10., 11., 12., 13., 14., 15., + 16., 17., 18., 19., 20., 21., 22., 23.}, + 2, vector{0.11920292, 0.11920292, 0.88079703, 0.88079703, + 0.11920292, 0.11920292, 0.88079703, 0.88079703, + 0.11920292, 0.11920292, 0.88079703, 0.88079703, + 0.11920292, 0.11920292, 0.88079703, 0.88079703, + 0.11920292, 0.11920292, 0.88079703, 0.88079703, + 0.11920292, 0.11920292, 0.88079703, 0.88079703}); + test_softmaxFp32( + Shape{2, 3, 2, 2}, + vector{0., 1., 2., 3., 4., 5., 6., 7., + 8., 9., 10., 11., 12., 13., 14., 15., + 16., 17., 18., 19., 20., 21., 22., 23.}, + 3, vector{0.26894143, 0.73105860, 0.26894143, 0.73105860, + 0.26894143, 0.73105860, 0.26894143, 0.73105860, + 0.26894143, 0.73105860, 0.26894143, 0.73105860, + 0.26894143, 0.73105860, 0.26894143, 0.73105860, + 0.26894143, 
0.73105860, 0.26894143, 0.73105860, + 0.26894143, 0.73105860, 0.26894143, 0.73105860}); +} // python output +TEST(CUDA_SoftmaxFP16, run) { + test_softmaxFp16(Shape{2, 3, 2, 2}, ValGenerator<2>(), 0, + vector{0.5000, 0.5000, 0.5000, 0.5000, 0.5000, + 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, + 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, + 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, + 0.5000, 0.5000, 0.5000, 0.5000}); + test_softmaxFp16( + Shape{2, 3, 2, 2}, ValGenerator<2>(), 1, // data accuracy down + vector{0.333252, 0.333252, 0.333252, 0.333252, 0.333252, + 0.333252, 0.333252, 0.333252, 0.333252, 0.333252, + 0.333252, 0.333252, 0.333252, 0.333252, 0.333252, + 0.333252, 0.333252, 0.333252, 0.333252, 0.333252, + 0.333252, 0.333252, 0.333252, 0.333252}); -TEST(cuDNN_Softmax2, run_axis1) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); - auto cudaRuntime = make_ref(); +} // python output - // Build input data on CPU - Tensor inputCpu = - make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); - - // GPU - Graph cudaGraph = make_ref(cudaRuntime); - auto inputGpu = cudaGraph->cloneTensor(inputCpu); - auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 1); - cudaGraph->dataMalloc(); - inputGpu->setData(IncrementalGenerator()); - cudaRuntime->run(cudaGraph); - auto outputGpu = gpuOp->getOutput(); - auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - cudaPrintTensor(outputGpu); - // Check - EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ - 0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138, - 0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862, - 0.9820138, 0.9820138, 0.9820138, 0.9820138})); -} - -TEST(cuDNN_Softmax2, run_axis2) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); - auto cudaRuntime = make_ref(); - - // Build input data on CPU - Tensor inputCpu = - make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); - - // GPU - Graph cudaGraph = make_ref(cudaRuntime); - auto inputGpu = cudaGraph->cloneTensor(inputCpu); - auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 2); - cudaGraph->dataMalloc(); - inputGpu->setData(IncrementalGenerator()); - cudaRuntime->run(cudaGraph); - auto outputGpu = gpuOp->getOutput(); - auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - cudaPrintTensor(outputGpu); - // Check - EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ - 0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029, - 0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971, - 0.1192029, 0.1192029, 0.8807971, 0.8807971})); -} - -TEST(cuDNN_Softmax2, run_axis3) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); - auto cudaRuntime = make_ref(); - - // Build input data on CPU - Tensor inputCpu = - make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); - - // GPU - Graph cudaGraph = make_ref(cudaRuntime); - auto inputGpu = cudaGraph->cloneTensor(inputCpu); - auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 3); - cudaGraph->dataMalloc(); - inputGpu->setData(IncrementalGenerator()); - cudaRuntime->run(cudaGraph); - auto outputGpu = gpuOp->getOutput(); - auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - cudaPrintTensor(outputGpu); - // Check - EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ - 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, - 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, - 0.2689414, 0.7310586, 0.2689414, 0.7310586})); -} } // namespace infini diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 
43700b77..bae607f1 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -130,5 +130,35 @@ TEST(Split, Cuda_dim0) { EXPECT_TRUE(o0Cpu->equalData(vector{0, 1, 2})); EXPECT_TRUE(o1Cpu->equalData(vector{3, 4, 5})); } +//---------------- +TEST(SplitFp16, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float16); + gCpu->dataMalloc(); + input->setData(ValGenerator<2>()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto op = gCuda->addOp(inputGpu, std::nullopt, 1, 3); + gCuda->dataMalloc(); + inputGpu->setData(ValGenerator<2>()); + + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + auto o2Cpu = gCpu->cloneTensor(op->getOutput(2)); + EXPECT_TRUE(o0Cpu->equalData(vector{ + 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.})); + EXPECT_TRUE(o1Cpu->equalData(vector{ + 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.})); + EXPECT_TRUE(o2Cpu->equalData(vector{ + 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.})); +} } // namespace infini diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 4a2e5e98..fd407dfd 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -40,6 +40,34 @@ void testUnary(const std::function &generator, EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); } +template +void testCast(const std::function &generator, + const Shape &shape, vector ansVec) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = + cudaGraph->addOp(inputGpu, nullptr, CastType::Float2Float16); + cudaGraph->dataMalloc(); + inputGpu->setData(generator); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + + inputCpu->printData(); + outputGpu2Cpu->printData(); + EXPECT_TRUE(outputGpu2Cpu->equalData(ansVec)); +} + TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); @@ -50,6 +78,8 @@ TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testCast(IncrementalGenerator(), Shape{8, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7}); // more shapes testUnary(IncrementalGenerator(), Shape{13}); testUnary(IncrementalGenerator(), Shape{4, 3}); diff --git a/test/kernels/cuda/test_cuda_where.cc b/test/kernels/cuda/test_cuda_where.cc index 32c2f253..07f5b48e 100644 --- a/test/kernels/cuda/test_cuda_where.cc +++ b/test/kernels/cuda/test_cuda_where.cc @@ -8,11 +8,11 @@ namespace infini { -void test_where(const Shape &inputXShape, const vector &inputXData, - const Shape &inputYShape, const vector &inputYData, - const Shape &conditionShape, - const vector &conditionData, - const vector &ExpectData) { +void 
test_whereFp32(const Shape &inputXShape, const vector &inputXData, + const Shape &inputYShape, const vector &inputYData, + const Shape &conditionShape, + const vector &conditionData, + const vector &ExpectData) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto condition = gCpu->addTensor(conditionShape, DataType::UInt8); @@ -43,22 +43,62 @@ void test_where(const Shape &inputXShape, const vector &inputXData, oCpu->printData(); //->printData EXPECT_TRUE(oCpu->equalData(ExpectData)); } +void test_whereFp16( + const Shape &inputXShape, + const std::function &generatorX, + const Shape &inputYShape, + const std::function &generatorY, + const Shape &conditionShape, const vector &conditionData, + const vector &ExpectData) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); -TEST(CUDA_Where, run) { - test_where( + auto inputX = gCpu->addTensor(inputXShape, DataType::Float16); + auto inputY = gCpu->addTensor(inputYShape, DataType::Float16); + auto condition = gCpu->addTensor(conditionShape, DataType::UInt8); + gCpu->dataMalloc(); + + inputX->setData(generatorX); + inputY->setData(generatorY); + condition->copyin(conditionData); // + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputXGpu = gCuda->cloneTensor(inputX); + auto inputYGpu = gCuda->cloneTensor(inputY); + auto conditionGpu = gCuda->cloneTensor(condition); + + auto op = gCuda->addOp(inputXGpu, inputYGpu, conditionGpu, + nullptr); // WhereObj + gCuda->dataMalloc(); + + inputXGpu->setData(generatorX); + inputYGpu->setData(generatorY); + conditionGpu->copyin(conditionData); + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectData)); +} + +TEST(CUDA_WhereFp32, run) { + test_whereFp32( Shape{2, 2, 3, 1}, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, Shape{2, 2, 3, 1}, vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, vector{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.}); - test_where(Shape{2, 1, 1, 3}, // inputx - vector{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy - vector{1, 1}, Shape{2, 1, 3, 1}, // condition - vector{0, 1, 1, 0, 0, 0}, - vector{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., - 0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); - test_where( + test_whereFp32(Shape{2, 1, 1, 3}, // inputx + vector{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy + vector{1, 1}, Shape{2, 1, 3, 1}, // condition + vector{0, 1, 1, 0, 0, 0}, + vector{1., 1., 1., 0., 1., 2., 0., 1., 2., + 1., 1., 1., 0., 1., 2., 0., 1., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + test_whereFp32( Shape{ 3, }, @@ -68,7 +108,7 @@ TEST(CUDA_Where, run) { vector{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3., 0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.}); - test_where( + test_whereFp32( Shape{ 3, }, @@ -80,6 +120,30 @@ TEST(CUDA_Where, run) { 0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.}); +} // python output +TEST(CUDA_WhereFp16, run) { + test_whereFp16( + Shape{ + 3, + }, + ValGenerator<1>(), // inputX + Shape{2, 3, 1}, ValGenerator<2>(), // inputY + Shape{2, 1, 3, 1}, vector{0, 1, 1, 0, 0, 0}, // condition + vector{2., 2., 2., 1., 1., 1., 1., 1., 1., 2., 2., 2., + 1., 1., 1., 1., 
1., 1., 2., 2., 2., 2., 2., 2., + 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}); + test_whereFp16( + Shape{ + 3, + }, + ValGenerator<1>(), // inputX + Shape{2, 3, 1}, ValGenerator<2>(), // inputY + Shape{2, 1, 3, 1}, + vector{false, true, true, false, false, false}, // condition + vector{2., 2., 2., 1., 1., 1., 1., 1., 1., 2., 2., 2., + 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., + 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}); + } // python output } // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_conv.cc b/test/kernels/intelcpu/test_mkl_conv.cc index 76ff2628..bcb3f3f4 100644 --- a/test/kernels/intelcpu/test_mkl_conv.cc +++ b/test/kernels/intelcpu/test_mkl_conv.cc @@ -53,8 +53,8 @@ TEST(mkl_Conv, tune) { mklRuntime->run(gMkl, tune); // check record - auto kernelAttrs = KernelAttrs{ - Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32}; + auto kernelAttrs = + KernelAttrs{Device::INTELCPU, conv->getOpType().underlying()}; auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; std::optional perfData = PerfEngine::getInstance().getPerfData(perfKey); diff --git a/test/kernels/intelcpu/test_mkl_conv_transposed.cc b/test/kernels/intelcpu/test_mkl_conv_transposed.cc index 40a33fcd..2f5624fb 100644 --- a/test/kernels/intelcpu/test_mkl_conv_transposed.cc +++ b/test/kernels/intelcpu/test_mkl_conv_transposed.cc @@ -74,7 +74,9 @@ TEST(mkl_ConvTransposed, tune) { runtime->run(gMkl, tune); // check record auto kernelAttrs = KernelAttrs{ - Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32}; + Device::INTELCPU, + conv->getOpType().underlying(), + }; auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; std::optional perfData = PerfEngine::getInstance().getPerfData(perfKey); diff --git a/test/kernels/intelcpu/test_mkl_pooling.cc b/test/kernels/intelcpu/test_mkl_pooling.cc index 5d25bb22..71b381e9 100644 --- a/test/kernels/intelcpu/test_mkl_pooling.cc +++ b/test/kernels/intelcpu/test_mkl_pooling.cc @@ -19,7 +19,7 @@ void testPoolMkl(const std::function &generator, // Build input data Tensor i0 = g->addTensor(shape, DataType::Float32); auto pool = g->addOp(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3], - kdps[4], kdps[5], kdps[6], kdps[7]); + kdps[4], kdps[5], kdps[6], kdps[7], 0); g->dataMalloc(); i0->setData(generator); diff --git a/test/kernels/intelcpu/test_mkl_reduce.cc b/test/kernels/intelcpu/test_mkl_reduce.cc index 859a1f91..e67789f1 100644 --- a/test/kernels/intelcpu/test_mkl_reduce.cc +++ b/test/kernels/intelcpu/test_mkl_reduce.cc @@ -2,7 +2,7 @@ #include "core/kernel.h" #include "core/runtime.h" #include "intelcpu/mkl_runtime.h" -#include "operators/reduce_mean.h" +#include "operators/reduce.h" #include "test.h" diff --git a/test/operators/test_unary.cc b/test/operators/test_unary.cc index 911d815e..be8be206 100644 --- a/test/operators/test_unary.cc +++ b/test/operators/test_unary.cc @@ -13,8 +13,9 @@ TEST(Unary, ShapeInference) { { Graph g = make_ref(runtime); Tensor i0 = g->addTensor({2}, DataType::Float32); - auto op = g->addOp(i0, nullptr); + auto op = g->addOp(i0, nullptr, CastType::Float2Float16); EXPECT_EQ(op->getOutput()->getDims(), (Shape{2})); + EXPECT_EQ(op->getOutDType(), (DataType::Float16)); } } diff --git a/test/operators/test_where.cc b/test/operators/test_where.cc index c32e2d81..6b90837f 100644 --- a/test/operators/test_where.cc +++ b/test/operators/test_where.cc @@ -7,7 +7,7 @@ namespace infini { -TEST(Where, ShapeInference) { +TEST(WhereFp32, ShapeInference) { Runtime runtime = 
NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); @@ -42,5 +42,39 @@ TEST(Where, ShapeInference) { EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 12, 224, 224})); } } - +TEST(WhereFp16, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({2, 2}, DataType::Float16); + Tensor y = g->addTensor({2, 2}, DataType::Float16); + Tensor con = g->addTensor({2, 2}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 2})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({1, 12, 224, 224}, DataType::Float16); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float16); + Tensor con = g->addTensor({1, 224, 1}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({12, 224, 224}, DataType::Float16); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float16); + Tensor con = g->addTensor({1, 224}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({12, 224, 224}, DataType::Float16); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float16); + Tensor con = g->addTensor({2, 1, 1, 1, 224}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 12, 224, 224})); + } +} } // namespace infini
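[Editor's note] The LayernormFp16 expectation of all 2s above follows directly from the constant ValGenerator<2>() fills: when every input element equals the mean, the normalized term is zero and the standard LayerNorm formula collapses to the bias. A minimal arithmetic check using the same eps = 1e-5 as the test (standalone, not part of the test suite):

#include <cmath>
#include <cstdio>

int main() {
    // Input, scale and bias tensors are all filled with the constant 2.
    const double x = 2.0, scale = 2.0, bias = 2.0, eps = 1e-5;
    const double mean = x;                                   // constant input
    const double variance = 0.0;                             // no spread
    const double normalized = (x - mean) / std::sqrt(variance + eps); // = 0
    std::printf("%g\n", normalized * scale + bias);          // prints 2
    return 0;
}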
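[Editor's note] Similarly, the SoftmaxFP16 test expects 0.333252 rather than 1/3 and flags the reduced accuracy in a comment. That value is simply the nearest representable binary16 number to 1/3 (10 stored mantissa bits), which the standalone program below reproduces; roundToHalf is a hypothetical helper written for this note, not an API of the project.

#include <cmath>
#include <cstdio>

// Round a double to the nearest IEEE-754 binary16 value (normal range only):
// keep 11 significant bits (1 implicit + 10 stored) and rescale.
static double roundToHalf(double x) {
    int exponent = 0;
    double mantissa = std::frexp(x, &exponent);    // x = mantissa * 2^exponent
    double scaled = std::ldexp(mantissa, 11);      // expose 11 leading bits
    return std::ldexp(std::nearbyint(scaled), exponent - 11);
}

int main() {
    // Softmax over 3 identical values is exactly 1/3 in real arithmetic,
    // but a half-precision kernel can only return the closest fp16 value.
    std::printf("%.6f\n", roundToHalf(1.0 / 3.0)); // prints 0.333252
    // Softmax over 2 identical values (the axis-0 case) is 0.5, which fp16
    // represents exactly, so that expectation needs no adjustment.
    std::printf("%.6f\n", roundToHalf(0.5));       // prints 0.500000
    return 0;
}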