From 09b2ecf98a28a5ac3bedb31ee9502484885a7580 Mon Sep 17 00:00:00 2001 From: Hardy <100662313+wanghailu0717@users.noreply.github.com> Date: Wed, 24 Jan 2024 13:33:33 +0800 Subject: [PATCH 1/3] support more data type on mlu (#211) * support more data type * clang format * fix little bug * fix cncl datatype * fix format --------- Co-authored-by: wanghailu Co-authored-by: Zhang Bolun --- include/bang/bang_common.h | 32 ++++ src/kernels/bang/activation.cc | 58 ++++--- src/kernels/bang/activation_backward.cc | 21 ++- src/kernels/bang/all_gather.cc | 9 +- src/kernels/bang/all_reduce.cc | 22 +-- src/kernels/bang/batchnorm.cc | 31 ++-- src/kernels/bang/broadcast.cc | 12 +- src/kernels/bang/ceil.cc | 13 +- src/kernels/bang/clip.cc | 7 +- src/kernels/bang/concat.cc | 16 +- src/kernels/bang/conv.cc | 36 +++-- src/kernels/bang/conv_trans.cc | 15 +- src/kernels/bang/convbpfilter.cc | 39 ++--- src/kernels/bang/det.cc | 13 +- src/kernels/bang/element_wise.cc | 194 +++++++++++------------- src/kernels/bang/erf.cc | 13 +- src/kernels/bang/exp.cc | 13 +- src/kernels/bang/fill.cc | 7 +- src/kernels/bang/floor.cc | 13 +- src/kernels/bang/gather.cc | 13 +- src/kernels/bang/hardtanh.cc | 4 +- src/kernels/bang/l2loss.cc | 4 +- src/kernels/bang/layer_norm.cc | 17 +-- src/kernels/bang/log.cc | 13 +- src/kernels/bang/lrn.cc | 13 +- src/kernels/bang/matmul.cc | 26 ++-- src/kernels/bang/negtensor.cc | 13 +- src/kernels/bang/pad.cc | 13 +- src/kernels/bang/pooling.cc | 11 +- src/kernels/bang/reciprocal.cc | 13 +- src/kernels/bang/reduce.cc | 17 +-- src/kernels/bang/rsqrt.cc | 13 +- src/kernels/bang/slice.cc | 8 +- src/kernels/bang/split.cc | 13 +- src/kernels/bang/sqrt.cc | 13 +- src/kernels/bang/transpose.cc | 26 ++-- src/kernels/bang/trigon.cc | 13 +- src/kernels/bang/where.cc | 19 ++- 38 files changed, 418 insertions(+), 408 deletions(-) diff --git a/include/bang/bang_common.h b/include/bang/bang_common.h index 16a09053..6c8047ef 100644 --- a/include/bang/bang_common.h +++ b/include/bang/bang_common.h @@ -2,6 +2,7 @@ #include "cnnl.h" #include "cnrt.h" #include "core/common.h" +#include "core/data_type.h" #define checkBangError(call) \ { \ @@ -27,4 +28,35 @@ namespace infini { using BangPtr = void *; +inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) { + if (dataType == DataType::Float32) { + return CNNL_DTYPE_FLOAT; + } + if (dataType == DataType::Float16) { + return CNNL_DTYPE_HALF; + } + if (dataType == DataType::Double) { + return CNNL_DTYPE_DOUBLE; + } + if (dataType == DataType::Int8) { + return CNNL_DTYPE_INT8; + } + if (dataType == DataType::Int32) { + return CNNL_DTYPE_INT32; + } + if (dataType == DataType::UInt8) { + return CNNL_DTYPE_UINT8; + } + if (dataType == DataType::BFloat16) { + return CNNL_DTYPE_BFLOAT16; + } + if (dataType == DataType::Int64) { + return CNNL_DTYPE_INT64; + } + if (dataType == DataType::Bool) { + return CNNL_DTYPE_BOOL; + } + return CNNL_DTYPE_INVALID; +} + } // namespace infini diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc index bc970760..41de7534 100644 --- a/src/kernels/bang/activation.cc +++ b/src/kernels/bang/activation.cc @@ -11,7 +11,6 @@ class UnaryCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -22,13 +21,13 @@ class UnaryCnnl : public BangKernelWithoutConfig { auto cDim = 
op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlActivationDescriptor_t opDesc; checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); checkCnnlError(cnnlSetActivationDescriptor_v2( @@ -51,7 +50,6 @@ class RoundCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -62,13 +60,13 @@ class RoundCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlRound(context->cnnlHandle(), aDesc, aData, cDesc, cData); if (stat != CNNL_STATUS_SUCCESS) @@ -82,7 +80,6 @@ class PReluCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -95,17 +92,17 @@ class PReluCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlPrelu(context->cnnlHandle(), aDesc, aData, bDesc, bData, cDesc, cData); @@ -122,7 +119,6 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == 
DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -185,13 +181,13 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, inDim.size(), - inDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + inDim.size(), inDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, outDim.size(), - outDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + outDim.size(), outDim.data())); float alpha = 1.0; float beta = 0.0; cnnlStatus_t stat = diff --git a/src/kernels/bang/activation_backward.cc b/src/kernels/bang/activation_backward.cc index c2c3baa6..0e9d8435 100644 --- a/src/kernels/bang/activation_backward.cc +++ b/src/kernels/bang/activation_backward.cc @@ -10,7 +10,6 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const yData = (op->getInputs(0)->getRawDataPtr()); @@ -25,21 +24,21 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig { auto diffxDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&yDesc)); - checkCnnlError(cnnlSetTensorDescriptor(yDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, yDim.size(), - yDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + yDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + yDim.size(), yDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&diffYDesc)); checkCnnlError(cnnlSetTensorDescriptor( - diffYDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffyDim.size(), - diffyDim.data())); + diffYDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + diffyDim.size(), diffyDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&xDesc)); - checkCnnlError(cnnlSetTensorDescriptor(xDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, xDim.size(), - xDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + xDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + xDim.size(), xDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&diffXDesc)); checkCnnlError(cnnlSetTensorDescriptor( - diffXDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffxDim.size(), - diffxDim.data())); + diffXDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + diffxDim.size(), diffxDim.data())); // get op descriptor cnnlActivationDescriptor_t opDesc; checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); diff --git a/src/kernels/bang/all_gather.cc b/src/kernels/bang/all_gather.cc index d44569fe..f35d71d4 100644 --- a/src/kernels/bang/all_gather.cc +++ b/src/kernels/bang/all_gather.cc @@ -19,18 +19,17 @@ class AllGatherCNCL : public BangKernelWithoutConfig { BangPtr output_temp = context->getWorkspace(op->getInputs(0)->getBytes() * world_size); // void *output = op->getOutput()->getRawDataPtr(); - // IT_ASSERT(op->getDType() == DataType::Float32); checkBangError(cnrtMalloc(&output_temp, op->getInputs(0)->getBytes() * world_size)); size_t bytes = op->getInputs(0)->getBytes(); - size_t count = bytes / op->getDType().getSize(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = 
dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); CNCL_CHECK( - cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue)); + cnclAllGather(input, output_temp, count, cnclUint8, comm, queue)); checkBangError(cnrtQueueSync(queue)); for (int i = 0; i < world_size; ++i) { Tensor output = op->getOutput(i); @@ -42,8 +41,8 @@ class AllGatherCNCL : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32, - AllGatherCNCL, "AllGather_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllGather, AllGatherCNCL, + "AllGather_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/all_reduce.cc b/src/kernels/bang/all_reduce.cc index 4e9266fb..c9e42c65 100644 --- a/src/kernels/bang/all_reduce.cc +++ b/src/kernels/bang/all_reduce.cc @@ -13,14 +13,14 @@ class AllReduceCNCL : public BangKernelWithoutConfig { auto context = dynamic_cast(_context); void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); - IT_ASSERT(op->getDType() == DataType::Float32); - size_t count = op->getInputs(0)->size(); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); // checkBangError(cnrtQueueSync(queue)); - CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(), + CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(), comm, queue)); checkBangError(cnrtQueueSync(queue)); } @@ -41,13 +41,13 @@ class AllReduceMaxCNCL : public AllReduceCNCL { cnclReduceOp_t getRedOp() const override { return cnclMax; } }; -REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32, - AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32, - AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32, - AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32, - AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, AllReduceSumCNCL, + "AllReduce_Sum_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, AllReduceProdCNCL, + "AllReduce_Prod_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, AllReduceMinCNCL, + "AllReduce_Min_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, AllReduceMaxCNCL, + "AllReduce_Max_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc index 31aba547..633f0d88 100644 --- a/src/kernels/bang/batchnorm.cc +++ b/src/kernels/bang/batchnorm.cc @@ -7,7 +7,6 @@ class BatchNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const input = (op->getInputs(0)->getRawDataPtr()); @@ -33,18 +32,18 @@ class BatchNormCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, dims.size(), - 
dims.data())); - checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, dims.size(), - dimsTrans)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, outDims.size(), - outDims.data())); - checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, outDims.size(), - dimsOutTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dims.size(), dims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + intransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + dims.size(), dimsTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + outDims.size(), outDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outtransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + outDims.size(), dimsOutTrans)); cnnlTransposeDescriptor_t opDesc; checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute)); @@ -53,9 +52,9 @@ class BatchNormCnnl : public BangKernelWithoutConfig { &wsSize); BangPtr wsData = context->getWorkspace(wsSize); BangPtr inputTrans = context->getWorkspace( - cnnlGetTensorElementNum(inDesc) * sizeof(float)); + cnnlGetTensorElementNum(inDesc) * op->getDType().getSize()); BangPtr outputTrans = context->getWorkspace( - cnnlGetTensorElementNum(inDesc) * sizeof(float)); + cnnlGetTensorElementNum(inDesc) * op->getDType().getSize()); cnnlStatus_t stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input, intransDesc, inputTrans, wsData, wsSize); @@ -67,7 +66,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t paraDesc; checkCnnlError(cnnlCreateTensorDescriptor(¶Desc)); checkCnnlError(cnnlSetTensorDescriptor( - paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + paraDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data())); float alpha = 1.f, beta = 0.f; diff --git a/src/kernels/bang/broadcast.cc b/src/kernels/bang/broadcast.cc index 411506c5..e0bc5879 100644 --- a/src/kernels/bang/broadcast.cc +++ b/src/kernels/bang/broadcast.cc @@ -13,22 +13,22 @@ class BroadcastCNCL : public BangKernelWithoutConfig { auto context = dynamic_cast(_context); void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); - IT_ASSERT(op->getDType() == DataType::Float32); - size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); // TODO: Using default stream 0 for now. 
- CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32, - op->getRoot(), comm, queue)); + CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(), + comm, queue)); checkBangError(cnrtQueueSync(queue)); } }; -REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32, - BroadcastCNCL, "Broadcast_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Broadcast, BroadcastCNCL, + "Broadcast_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/ceil.cc b/src/kernels/bang/ceil.cc index ce77741d..61c1c98b 100644 --- a/src/kernels/bang/ceil.cc +++ b/src/kernels/bang/ceil.cc @@ -7,7 +7,6 @@ class CeilCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class CeilCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlCeil(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/clip.cc b/src/kernels/bang/clip.cc index fdb682f0..074583af 100644 --- a/src/kernels/bang/clip.cc +++ b/src/kernels/bang/clip.cc @@ -7,7 +7,6 @@ class ClipCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -19,9 +18,9 @@ class ClipCnnl : public BangKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); cnnlStatus_t stat = cnnlClip(context->cnnlHandle(), aDesc, aData, &min, &max, cData); if (stat != CNNL_STATUS_SUCCESS) diff --git a/src/kernels/bang/concat.cc b/src/kernels/bang/concat.cc index dae092c5..2350842c 100644 --- a/src/kernels/bang/concat.cc +++ b/src/kernels/bang/concat.cc @@ -7,7 +7,6 @@ class ConcatCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numInputs(); int axis = op->getDim(); @@ -15,17 +14,18 @@ class ConcatCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); cnnlTensorDescriptor_t desc; checkCnnlError(cnnlCreateTensorDescriptor(&desc)); - checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + 
checkCnnlError(cnnlSetTensorDescriptor( + desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlTensorDescriptor_t descArray[num]; for (int i = 0; i < num; ++i) { checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); - checkCnnlError(cnnlSetTensorDescriptor( - descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, - op->getInputs(i)->getDims().size(), - op->getInputs(i)->getDims().data())); + checkCnnlError( + cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW, + cnnlDataTypeConvert(op->getDType()), + op->getInputs(i)->getDims().size(), + op->getInputs(i)->getDims().data())); } void *argv[num]; diff --git a/src/kernels/bang/conv.cc b/src/kernels/bang/conv.cc index 24d8a3fd..655a2fa7 100644 --- a/src/kernels/bang/conv.cc +++ b/src/kernels/bang/conv.cc @@ -7,7 +7,6 @@ class ConvCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -55,20 +55,24 @@ class ConvCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aInDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aInDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inputs0)); + checkCnnlError(cnnlSetTensorDescriptor( + aInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs0)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs0Array)); + aDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + inputs0Array)); checkCnnlError(cnnlCreateTensorDescriptor(&bInDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bInDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inputs1)); + checkCnnlError(cnnlSetTensorDescriptor( + bInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs1)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs1Array)); + bDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + inputs1Array)); int permute[4] = {0, 2, 3, 1}; cnnlTransposeDescriptor_t opDesc; @@ -80,7 +84,7 @@ class ConvCnnl : public BangKernelWithoutConfig { &wsSize); BangPtr wsData = context->getWorkspace(wsSize); BangPtr aDataOut = context->getWorkspace( - cnnlGetTensorElementNum(aInDesc) * sizeof(float)); + cnnlGetTensorElementNum(aInDesc) * op->getDType().getSize()); cnnlStatus_t stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, aInDesc, aData, aDesc, aDataOut, wsData, wsSize); @@ -91,7 +95,7 @@ class ConvCnnl : public BangKernelWithoutConfig { &wsSize); wsData = context->getWorkspace(wsSize); BangPtr bDataOut = context->getWorkspace( - cnnlGetTensorElementNum(bInDesc) * sizeof(float)); + cnnlGetTensorElementNum(bInDesc) * op->getDType().getSize()); stat = 
cnnlTranspose_v2(context->cnnlHandle(), opDesc, bInDesc, bData, bDesc, bDataOut, wsData, wsSize); if (stat != CNNL_STATUS_SUCCESS) @@ -100,11 +104,13 @@ class ConvCnnl : public BangKernelWithoutConfig { // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cInDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cInDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, outputArray)); + cInDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + outputArray)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, output)); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + output)); cnnlConvolutionForwardAlgo_t algo; cnnlGetConvolutionForwardAlgorithm(context->cnnlHandle(), convDesc, @@ -116,7 +122,7 @@ class ConvCnnl : public BangKernelWithoutConfig { algo, &wsSize); wsData = context->getWorkspace(wsSize); BangPtr cDataIn = context->getWorkspace( - cnnlGetTensorElementNum(cInDesc) * sizeof(float)); + cnnlGetTensorElementNum(cInDesc) * op->getDType().getSize()); stat = cnnlConvolutionForward( context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc, diff --git a/src/kernels/bang/conv_trans.cc b/src/kernels/bang/conv_trans.cc index ce93fc9a..b2fad8ec 100644 --- a/src/kernels/bang/conv_trans.cc +++ b/src/kernels/bang/conv_trans.cc @@ -7,7 +7,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvTransCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -43,14 +43,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimInputs0.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data())); + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimInputs1.data())); // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data())); + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimOutput.data())); cnnlConvolutionBwdDataAlgo_t algo; cnnlGetConvolutionBackwardDataAlgorithm( diff --git a/src/kernels/bang/convbpfilter.cc b/src/kernels/bang/convbpfilter.cc index f3e9ec94..7d2930d8 100644 --- a/src/kernels/bang/convbpfilter.cc +++ b/src/kernels/bang/convbpfilter.cc @@ -7,7 +7,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - 
IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -63,15 +63,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs0Array)); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs0Array)); checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(aDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - inputs0ArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + aDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, inputs0ArrayTrans)); size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] * - dimInputs0[3] * sizeof(float); + dimInputs0[3] * op->getDType().getSize(); BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size); cnnlStatus_t stat = @@ -82,15 +83,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs1Array)); + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs1Array)); checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(bDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - inputs1ArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + bDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, inputs1ArrayTrans)); size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] * - dimInputs1[3] * sizeof(float); + dimInputs1[3] * op->getDType().getSize(); BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size); stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData, @@ -101,15 +103,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, outputArray)); + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + outputArray)); checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(cDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - outputArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + cDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, outputArrayTrans)); size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] * - dimOutput[3] * sizeof(float); + dimOutput[3] * op->getDType().getSize(); BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size); cnnlConvolutionBwdFilterAlgo_t algo; diff --git a/src/kernels/bang/det.cc b/src/kernels/bang/det.cc index eeb197b6..03f25041 100644 --- a/src/kernels/bang/det.cc +++ b/src/kernels/bang/det.cc @@ -7,7 +7,6 @@ class DetCnnl : public 
BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -24,14 +23,14 @@ class DetCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimin.size(), - dimin.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimin.size(), dimin.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimout.size(), - dimout.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimout.size(), dimout.data())); cnnlStatus_t stat = cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/element_wise.cc b/src/kernels/bang/element_wise.cc index e919e7d1..da18a6d4 100644 --- a/src/kernels/bang/element_wise.cc +++ b/src/kernels/bang/element_wise.cc @@ -11,7 +11,6 @@ class ElementWiseCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -31,24 +30,25 @@ class ElementWiseCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); cnnlOpTensorDescriptor_t opDesc; checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc)); checkCnnlError(cnnlSetOpTensorDescriptor( - opDesc, getOpType(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN)); + opDesc, getOpType(), cnnlDataTypeConvert(op->getDType()), + CNNL_NOT_PROPAGATE_NAN)); size_t wsSize; cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -75,7 +75,6 @@ class LogicOpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -95,17 +94,17 @@ class LogicOpCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + 
checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetLogicOpWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -129,7 +128,6 @@ class BitComputeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -182,7 +180,6 @@ class DivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -202,17 +199,17 @@ class DivCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -235,7 +232,6 @@ class MaximumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -255,17 +251,17 @@ class MaximumCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), 
b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); @@ -287,7 +283,6 @@ class MinimumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -307,17 +302,17 @@ class MinimumCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); @@ -339,7 +334,6 @@ class MSELossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -359,18 +353,18 @@ class MSELossCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); cnnlStatus_t stat; if (reduction == MSELossObj::None) { stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc, @@ -396,7 +390,6 @@ class PowerCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = 
(op->getInputs(0)->getRawDataPtr()); @@ -417,17 +410,17 @@ class PowerCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -450,7 +443,6 @@ class FloorDivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -470,17 +462,17 @@ class FloorDivCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetFloorDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -503,7 +495,6 @@ class FloorModCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -523,17 +514,17 @@ class FloorModCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); 
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetFloorModWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -556,7 +547,6 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -576,17 +566,17 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetSquaredDifferenceWorkspaceSize(context->cnnlHandle(), aDesc, diff --git a/src/kernels/bang/erf.cc b/src/kernels/bang/erf.cc index dcf8eacd..4d4f5b73 100644 --- a/src/kernels/bang/erf.cc +++ b/src/kernels/bang/erf.cc @@ -7,7 +7,6 @@ class ErfCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class ErfCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/exp.cc b/src/kernels/bang/exp.cc index 4b3d88ab..fbcd9485 100644 --- a/src/kernels/bang/exp.cc +++ b/src/kernels/bang/exp.cc @@ -7,7 +7,6 @@ class ExpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ 
-18,13 +17,13 @@ class ExpCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/fill.cc b/src/kernels/bang/fill.cc index c2de64d5..b2eebbb7 100644 --- a/src/kernels/bang/fill.cc +++ b/src/kernels/bang/fill.cc @@ -7,7 +7,6 @@ class FillCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -17,9 +16,9 @@ class FillCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlFill(context->cnnlHandle(), value, cDesc, cData); diff --git a/src/kernels/bang/floor.cc b/src/kernels/bang/floor.cc index 83d8b505..d1c4513c 100644 --- a/src/kernels/bang/floor.cc +++ b/src/kernels/bang/floor.cc @@ -7,7 +7,6 @@ class FloorCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class FloorCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlFloor(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc index 97fa395c..63c0a872 100644 --- a/src/kernels/bang/gather.cc +++ b/src/kernels/bang/gather.cc @@ -7,7 +7,6 @@ class GatherCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,9 +19,9 @@ class GatherCnnl : public 
BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError( cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST)); @@ -30,9 +29,9 @@ class GatherCnnl : public BangKernelWithoutConfig { CNNL_DTYPE_INT32, bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); BangPtr wsData = context->getWorkspace(aDim.size() * 4); context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4); diff --git a/src/kernels/bang/hardtanh.cc b/src/kernels/bang/hardtanh.cc index 1f91084e..63f7d36c 100644 --- a/src/kernels/bang/hardtanh.cc +++ b/src/kernels/bang/hardtanh.cc @@ -7,7 +7,6 @@ class HardtanhCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,7 +19,8 @@ class HardtanhCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlStatus_t stat = cnnlHardtanh(context->cnnlHandle(), aDesc, aData, max, min, aDesc, cData); diff --git a/src/kernels/bang/l2loss.cc b/src/kernels/bang/l2loss.cc index deb127be..659fde6b 100644 --- a/src/kernels/bang/l2loss.cc +++ b/src/kernels/bang/l2loss.cc @@ -7,7 +7,6 @@ class L2LossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,7 +17,8 @@ class L2LossCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlStatus_t stat = cnnlL2Loss(context->cnnlHandle(), aDesc, aData, cData); diff --git a/src/kernels/bang/layer_norm.cc b/src/kernels/bang/layer_norm.cc index acd36624..ed33bcf6 100644 --- a/src/kernels/bang/layer_norm.cc +++ b/src/kernels/bang/layer_norm.cc @@ -8,7 +8,6 @@ class LayerNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inputData = (op->getInputs(0)->getRawDataPtr()); @@ -29,17 +28,17 @@ class LayerNormCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); - 
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, inDims.size(), - inDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + inDims.size(), inDims.data())); checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc)); checkCnnlError(cnnlSetTensorDescriptor( - fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(), - fiterDims.data())); + fiterDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + fiterDims.size(), fiterDims.data())); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, outDims.size(), - outDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + outDims.size(), outDims.data())); size_t wsSize; cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc, &wsSize); diff --git a/src/kernels/bang/log.cc b/src/kernels/bang/log.cc index c2a3e566..74fcab72 100644 --- a/src/kernels/bang/log.cc +++ b/src/kernels/bang/log.cc @@ -7,7 +7,6 @@ class LogCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -33,13 +32,13 @@ class LogCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/lrn.cc b/src/kernels/bang/lrn.cc index 14bca5fb..254b59ae 100644 --- a/src/kernels/bang/lrn.cc +++ b/src/kernels/bang/lrn.cc @@ -7,7 +7,6 @@ class LRNCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,13 +19,13 @@ class LRNCnnl : public BangKernelWithoutConfig { auto size = op->getSize(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); size_t extra_size; cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc, diff --git a/src/kernels/bang/matmul.cc 
b/src/kernels/bang/matmul.cc index 09780067..9afb2377 100644 --- a/src/kernels/bang/matmul.cc +++ b/src/kernels/bang/matmul.cc @@ -8,7 +8,6 @@ class MatmulCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); auto input_num = op->numInputs(); @@ -38,25 +37,26 @@ class MatmulCnnl : public BangKernelWithoutConfig { int32_t transB = op->getTransB(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimInputs0.size(), dimInputs0.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimInputs0.size(), dimInputs0.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimInputs1.size(), dimInputs1.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimInputs1.size(), dimInputs1.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimOutput.size(), dimOutput.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimOutput.size(), dimOutput.data())); if (input_num > 2) { checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc)); - checkCnnlError(cnnlSetTensorDescriptor( - biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(), - dimBias.data())); + checkCnnlError( + cnnlSetTensorDescriptor(biasDesc, CNNL_LAYOUT_ARRAY, + cnnlDataTypeConvert(op->getDType()), + dimBias.size(), dimBias.data())); } cnnlMatMulDescriptor_t bmm_desc; diff --git a/src/kernels/bang/negtensor.cc b/src/kernels/bang/negtensor.cc index 12377610..170138ac 100644 --- a/src/kernels/bang/negtensor.cc +++ b/src/kernels/bang/negtensor.cc @@ -7,7 +7,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class NegTensorCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlNegTensor(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/pad.cc b/src/kernels/bang/pad.cc index e8aafa1a..c35b92e1 100644 --- a/src/kernels/bang/pad.cc +++ b/src/kernels/bang/pad.cc @@ -7,7 +7,6 @@ class PadCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == 
DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -37,14 +36,14 @@ class PadCnnl : public BangKernelWithoutConfig { float paddingValue = 0.0; // input checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimIn.size(), - dimIn.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimIn.size(), dimIn.data())); // output checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimOut.size(), - dimOut.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimOut.size(), dimOut.data())); cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData, paddings, &paddingValue, cDesc, cData); diff --git a/src/kernels/bang/pooling.cc b/src/kernels/bang/pooling.cc index 90a0637f..95db2185 100644 --- a/src/kernels/bang/pooling.cc +++ b/src/kernels/bang/pooling.cc @@ -8,7 +8,6 @@ class PoolingCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); @@ -20,8 +19,9 @@ class PoolingCnnl : public BangKernelWithoutConfig { int inArray[4] = {n, c, h, w}; cnnlTensorDescriptor_t inDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inArray)); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inArray)); bool mode = op->getCeilMode(); // get maxpool descriptor @@ -37,8 +37,9 @@ class PoolingCnnl : public BangKernelWithoutConfig { int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]}; cnnlTensorDescriptor_t outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, outArray)); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + outArray)); size_t wsSize; cnnlGetPoolingWorkspaceSize(context->cnnlHandle(), getPoolingMode(), outVec[3], outVec[2], &wsSize); diff --git a/src/kernels/bang/reciprocal.cc b/src/kernels/bang/reciprocal.cc index 7b61c2ca..1c95393a 100644 --- a/src/kernels/bang/reciprocal.cc +++ b/src/kernels/bang/reciprocal.cc @@ -7,7 +7,6 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, 
CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlReciprocal(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/reduce.cc b/src/kernels/bang/reduce.cc index 810aca72..374a0d0a 100644 --- a/src/kernels/bang/reduce.cc +++ b/src/kernels/bang/reduce.cc @@ -9,7 +9,6 @@ class ReduceCnnlBase : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -26,20 +25,20 @@ class ReduceCnnlBase : public BangKernelWithoutConfig { cnnlTensorDescriptor_t inDesc, outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); // get reduce descriptor cnnlReduceDescriptor_t reduceDesc; checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc)); checkCnnlError(cnnlSetReduceDescriptor_v2( reduceDesc, axes.data(), axes.size(), getReduceOp(), - CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, - CNNL_32BIT_INDICES, 0.0)); + cnnlDataTypeConvert(op->getDType()), CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0)); // get workspace size_t workspaceSize = 0; diff --git a/src/kernels/bang/rsqrt.cc b/src/kernels/bang/rsqrt.cc index 66e63e0a..f6ba8e12 100644 --- a/src/kernels/bang/rsqrt.cc +++ b/src/kernels/bang/rsqrt.cc @@ -7,7 +7,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class RsqrtCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/slice.cc b/src/kernels/bang/slice.cc index 5cc772aa..303ce741 100644 --- a/src/kernels/bang/slice.cc +++ b/src/kernels/bang/slice.cc @@ -42,11 +42,13 @@ class SliceCnnl : public BangKernelWithoutConfig { // 
input checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array)); + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim_size, aDim_array)); // output checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array)); + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + cDim_size, cDim_array)); cnnlStatus_t stat = cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array, @@ -59,6 +61,6 @@ class SliceCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, +REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl, "Slice_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/split.cc b/src/kernels/bang/split.cc index 397b5063..d7342946 100644 --- a/src/kernels/bang/split.cc +++ b/src/kernels/bang/split.cc @@ -7,7 +7,6 @@ class SplitCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numOutputs(); int axis = op->getDim(); @@ -16,15 +15,17 @@ class SplitCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t desc; checkCnnlError(cnnlCreateTensorDescriptor(&desc)); checkCnnlError(cnnlSetTensorDescriptor( - desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlTensorDescriptor_t descArray[num]; for (int i = 0; i < num; ++i) { checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); - checkCnnlError(cnnlSetTensorDescriptor( - descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, - op->getOutput(i)->getDims().size(), - op->getOutput(i)->getDims().data())); + checkCnnlError( + cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW, + cnnlDataTypeConvert(op->getDType()), + op->getOutput(i)->getDims().size(), + op->getOutput(i)->getDims().data())); } void *const inputData = (op->getInputs(0)->getRawDataPtr()); diff --git a/src/kernels/bang/sqrt.cc b/src/kernels/bang/sqrt.cc index a1ed85c9..59bc8dcb 100644 --- a/src/kernels/bang/sqrt.cc +++ b/src/kernels/bang/sqrt.cc @@ -7,7 +7,6 @@ class SqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class SqrtCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git 
a/src/kernels/bang/transpose.cc b/src/kernels/bang/transpose.cc index 7dedd21d..c69fa455 100644 --- a/src/kernels/bang/transpose.cc +++ b/src/kernels/bang/transpose.cc @@ -7,7 +7,6 @@ class TransposeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class TransposeCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimin.size(), - dimin.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimin.size(), dimin.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimout.size(), - dimout.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimout.size(), dimout.data())); auto permute = op->getPermute(); cnnlTransposeDescriptor_t opDesc; @@ -53,7 +52,6 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -73,13 +71,13 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, reshape.size(), - reshape.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + reshape.size(), reshape.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - transpose.size(), transpose.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + transpose.size(), transpose.data())); cnnlTransposeDescriptor_t opDesc; checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); diff --git a/src/kernels/bang/trigon.cc b/src/kernels/bang/trigon.cc index 989858c4..d6ce8273 100644 --- a/src/kernels/bang/trigon.cc +++ b/src/kernels/bang/trigon.cc @@ -9,7 +9,6 @@ class TrigonCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,13 +19,13 @@ class TrigonCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + 
checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlTrigonDescriptor_t opDesc; checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc)); diff --git a/src/kernels/bang/where.cc b/src/kernels/bang/where.cc index 8786f3fd..4ed26bd7 100644 --- a/src/kernels/bang/where.cc +++ b/src/kernels/bang/where.cc @@ -7,7 +7,6 @@ class WhereCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,21 +34,21 @@ class WhereCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_BOOL, cDim.size(), cDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&dDesc)); - checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dDim.size(), - dDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + dDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dDim.size(), dDim.data())); size_t wsSize; cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc, &wsSize); From a5062f3f8948f1099022d556029f75b1e28b2a31 Mon Sep 17 00:00:00 2001 From: Haojie Wang Date: Wed, 24 Jan 2024 22:16:48 +0800 Subject: [PATCH 2/3] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 714e8c77..5f1e8bd0 100644 --- a/README.md +++ b/README.md @@ -33,13 +33,14 @@ There are several configurable CMake options, see the [CMakeLists.txt](/CMakeLis ## Roadmap +- [RefactorGraph](https://github.com/InfiniTensor/RefactorGraph) is a newly designed AI framework that is set to replace the current main branch. - [EinNet](https://github.com/InfiniTensor/InfiniTensor/tree/NNET_e2e) is going to be merged into the main branch. - Integration of [PET](https://github.com/thu-pacman/PET), a tensor program optimizer supporting partially equivalent transformations. 
 - Supported hardware
   - ✔ NVIDIA GPU
   - ✔ Cambricon MLU
+  - ✔ Kunlunxin XPU
   - ⬜ Ascend NPU
-  - ⬜ Kunlunxin XPU
 
 ## Contributor Guide
 

From d1a90ba3e22906b1e7b0160acc7ec4a0eff639b1 Mon Sep 17 00:00:00 2001
From: xiaonans <51065160+xiaonans@users.noreply.github.com>
Date: Thu, 25 Jan 2024 14:20:43 +0800
Subject: [PATCH 3/3] [feature] support kvcache with static graph (#209)

* [feature] support kvcache with static graph

* use workspace to optimize kvcache attention

---------

Co-authored-by: Haojie Wang
---
 examples/python/llama_kvcache_inference.py | 145 +++++++++++++
 include/cuda/cuda_attention_kvcache.h      |   4 +-
 src/core/graph_handler.cc                  |   2 +-
 src/kernels/cuda/attention_kvcache.cc      |  11 +-
 src/kernels/cuda/attention_kvcache.cu      | 227 ++++++++++++---------
 test/kernels/cuda/test_cuda_attention.cc   |  17 +-
 6 files changed, 301 insertions(+), 105 deletions(-)
 create mode 100644 examples/python/llama_kvcache_inference.py

diff --git a/examples/python/llama_kvcache_inference.py b/examples/python/llama_kvcache_inference.py
new file mode 100644
index 00000000..e6ba67ff
--- /dev/null
+++ b/examples/python/llama_kvcache_inference.py
@@ -0,0 +1,145 @@
+import os
+from pyinfinitensor.onnx import OnnxStub, backend
+import numpy as np
+import onnx
+import torch
+from transformers import LlamaModel, LlamaForCausalLM
+from tqdm import tqdm
+import onnx_graphsurgeon as gs
+from onnxsim import simplify
+import argparse
+
+parser = argparse.ArgumentParser(description='')
+parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
+parser.add_argument('--layer', dest='n_layers', type=int, default=2)
+parser.add_argument('--iter', dest='n_iter', type=int, default=1)
+parser.add_argument('--n_max_length', dest='n_max_length', type=int, default=1024)
+parser.add_argument('--pretrained_llama_path', dest='pretrained_llama_path', type=str,
+                    default="/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/")
+parser.add_argument('--onnx_model_path', dest='onnx_model_path', type=str,
+                    default="/data1/shared/llama")
+args = parser.parse_args()
+
+ONNX_MODEL_PATH = "{}/llama_bs{}_layer{}.onnx".format(args.onnx_model_path, args.batchsize, args.n_layers)
+ONNX_WEIGHT_PATH = "./llama_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)
+
+def export_onnx(model: LlamaModel, ONNX_MODEL_PATH):
+    param = torch.zeros(
+        (args.batchsize, 1024), dtype=torch.long)
+    logits = model(param, past_key_values=None)
+    param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
+
+    torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
+                      "position_ids": param_kvcache}), ONNX_MODEL_PATH, verbose=False,
+                      do_constant_folding=True,)
+    onnx_model = onnx.load(ONNX_MODEL_PATH)
+    print("simplifying onnx model")
+    onnx_model, check = simplify(onnx_model, skipped_optimizers=[
+        'eliminate_duplicate_initializer'])
+    assert check
+
+    onnx.save(onnx_model, ONNX_MODEL_PATH, save_as_external_data=True, location=ONNX_WEIGHT_PATH)
+    print("simplifying finished.")
+
+
+@gs.Graph.register()
+def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
+    for inp in inputs:
+        inp.outputs.clear()
+    for out in outputs:
+        out.inputs.clear()
+    for inp in inputs_added:
+        inputs.append(inp)
+    for out in outputs_removed:
+        out.inputs.clear()
+    return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs)
+
+
+def replace_onnx_with_attention_op():
+    graph = gs.import_onnx(
+        onnx.load(ONNX_MODEL_PATH))
+    tmap = graph.tensors()
+    for i in range(args.n_layers):
+        inputs = [
+            tmap["onnx::Concat_" + str((i+1)*2)],
+            tmap["onnx::Concat_" + str((i+1)*2+1)],
+            tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
+            tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
+            tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
+        outputs = [
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
+
+        inputs_added = [graph.inputs[1]]
+        outputs_removed = []
+
+        graph.replace_with_attention(
+            inputs, outputs, inputs_added, outputs_removed)
+
+    graph.outputs = [tmap[graph.outputs[0].name]]
+    graph.cleanup(True).toposort()
+    onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)
+
+
+if __name__ == "__main__":
+    kvcache_torch = None
+    torch_model = LlamaForCausalLM.from_pretrained(
+        args.pretrained_llama_path, num_hidden_layers=int(args.n_layers)).eval()
+
+    n_heads = torch_model.config.num_attention_heads
+    n_dims = torch_model.config.hidden_size // n_heads
+
+    if not os.path.exists(ONNX_MODEL_PATH):
+        print("exporting onnx graph")
+        export_onnx(torch_model, ONNX_MODEL_PATH)
+        replace_onnx_with_attention_op()
+    else:
+        print("will use existing onnx graph")
+
+    onnx_model = onnx.load(ONNX_MODEL_PATH)
+    stub = OnnxStub(onnx_model, backend.cuda_runtime())
+
+    count_wrong = 0
+    for i in tqdm(range(0, args.n_max_length)):
+        query = np.random.randint(
+            torch_model.config.vocab_size, size=(args.batchsize, 1), dtype=np.int32)
+        position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
+
+        ####################################
+        # pytorch
+        ####################################
+        outputs_torch = torch_model(
+            torch.tensor(query), past_key_values=kvcache_torch)
+        logit_torch = outputs_torch['logits']
+        kvcache_torch = outputs_torch['past_key_values']
+
+        ####################################
+        # infinitensor
+        ####################################
+        # copyin input
+        (list(stub.inputs.items()))[0][1].copyin_int64(
+            query.reshape(-1).tolist())
+        (list(stub.inputs.items()))[1][1].copyin_int64(
+            position_id.reshape(-1).tolist())
+
+        stub.run()
+
+        ####################################
+        # validation
+        ####################################
+        # copyout output
+        logits_it = np.array((list(stub.outputs.items()))
+                             [0][1].copyout_float())
+
+        try:
+            np.testing.assert_allclose(
+                logit_torch[:, -1, :].detach().cpu().numpy().flatten(), logits_it, rtol=1e-3, atol=1e-3)
+        except Exception as e:
+            try:
+                np.testing.assert_allclose(
+                    np.argmax(logit_torch[:, -1, :].detach().cpu().numpy().flatten()), np.argmax(logits_it), rtol=1e-3, atol=1e-3)
+            except:
+                count_wrong = count_wrong + 1
+
+    result = "{}/{} failed.".format(count_wrong, args.n_max_length)
+    print(result)
+    del stub
diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h
index 880a814f..91c65d21 100644
--- a/include/cuda/cuda_attention_kvcache.h
+++ b/include/cuda/cuda_attention_kvcache.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include 
 
 struct AttentionKVCacheMetadata {
@@ -10,6 +11,7 @@ namespace infini {
 void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
                               float *input_q, float *input_k, float *input_v,
                               int *position_id, float *output_matmul,
-                              const AttentionKVCacheMetadata &compMeta);
+                              const AttentionKVCacheMetadata &compMeta,
+                              float *output_O_temp, float *output_sum_temp);
 
 } // namespace infini
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 415ea947..cd62ed32 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -334,7 +334,7 @@ Tensor
GraphHandlerObj::attentionKVCache(Tensor input_k_cache, std::move(input_k_cache), std::move(input_v_cache), std::move(input_q), std::move(input_k), std::move(input_v), std::move(position_id), output_matmul); - return {output_matmul}; + return output_matmul; } else { return g ->addOp( diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc index 52356d8d..d72e7838 100644 --- a/src/kernels/cuda/attention_kvcache.cc +++ b/src/kernels/cuda/attention_kvcache.cc @@ -21,7 +21,7 @@ class AttentionKVCacheCompute { public: void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, - Tensor output_matmul) const { + Tensor output_matmul, CudaPtr p_workspace) const { AttentionKVCacheMetadata metadata; initAttentionKVCacheMetadata(metadata, input_v_cache); @@ -32,7 +32,8 @@ class AttentionKVCacheCompute { input_v->getRawDataPtr(), position_id->getRawDataPtr(), output_matmul->getRawDataPtr(), - metadata); + metadata, (float *)p_workspace, + (float *)(p_workspace + (1ll << 30))); } }; @@ -41,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute, void compute(const Operator &_op, const RuntimeObj *_context) const override { IT_ASSERT(_op->getDType() == DataType::Float32); + + size_t workspaceSize = 2ll << 30; + auto context = dynamic_cast(_context); + CudaPtr idxWsData = context->getWorkspace(workspaceSize); do_compute(_op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2], _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5], - _op->getOutputs()[0]); + _op->getOutputs()[0], idxWsData); } }; diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu index ece6659f..f169a4b1 100644 --- a/src/kernels/cuda/attention_kvcache.cu +++ b/src/kernels/cuda/attention_kvcache.cu @@ -2,127 +2,168 @@ #include "cuda/cuda_attention_kvcache.h" #define WARP_SIZE 32 #define BLOCKSIZE WARP_SIZE -#define SEQ_UNIT 64 +#define SEQ_UNIT 32 -__global__ void _attention_kvcache_kernel(float* input_k_cache, +// ASSUME SEQ_LEN OF Q IS 1 +__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, float* input_v_cache, float* input_q, float* input_k, float* input_v, int* position_id, - float* output_matmul, - AttentionKVCacheMetadata compMeta) { + AttentionKVCacheMetadata compMeta, + float* output_O_temp, + float* output_sum_temp) { + int seq_length = position_id[0] + 1; + int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT; + if(blockIdx.y >= stride) + return; + int lane_id = threadIdx.x % WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; + int idx_seq = blockIdx.y * SEQ_UNIT; if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1]) return; - float ptr_V[SEQ_UNIT*2]; - float ptr_K[SEQ_UNIT*2]; - float ptr_Q[2]; - float ptr_P[SEQ_UNIT]; + float ptr_V[SEQ_UNIT*4]; // V + float ptr_K[SEQ_UNIT*4]; // K + float ptr_Q[4]; // Q + float ptr_P[SEQ_UNIT] = {0}; - float ptr_O[2]; - float ptr_max[1]; - float ptr_sum[1]; + float ptr_O[4] = {0}; + float ptr_sum[1] = {0}; - float ptr_max_last[1]; - float ptr_sum_last[1]; - float ptr_O_last[2]; + // readin Q + (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; + int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]); - (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)]; - - int SEQ_LENGTH = position_id[0] + 1; - - int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]); 
- - - for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){ - ptr_max_last[0] = ptr_max[0]; - ptr_sum_last[0] = ptr_sum[0]; - (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0]; + // Q*K + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + (float4 &)ptr_K[idx_SEQ_UNIT * 4] + = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; + } + else{ + (float4 &)ptr_K[idx_SEQ_UNIT * 4] + = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; + (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = + (float4 &)ptr_K[idx_SEQ_UNIT * 4]; + } + #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ - (float2 &)ptr_K[idx_SEQ_UNIT * 2] - = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; - } - else{ - (float2 &)ptr_K[idx_SEQ_UNIT * 2] - = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; - (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = - (float2 &)ptr_K[idx_SEQ_UNIT * 2]; - } - ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2]; - ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1]; - + for (int i = 0; i < 4; i ++){ + ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i]; #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { - ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset); + ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset); } - ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2]; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2){ - ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset); - } - ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)]; + ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i]; } - - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); - ptr_P[idx_SEQ_UNIT] /= 8; - ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]); - } - ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]); - - ptr_sum[0] = 0; - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]); - ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; - } - ptr_sum[0] = (idx_seq == 0) ? 
ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0]; - - ptr_O[0] = 0; - ptr_O[1] = 0; - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ - (float2 &)ptr_V[idx_SEQ_UNIT * 2] - = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; - } - else{ - (float2 &)ptr_V[idx_SEQ_UNIT * 2] - = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; - (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = - (float2 &)ptr_V[idx_SEQ_UNIT * 2]; - } - - ptr_P[idx_SEQ_UNIT] /= ptr_sum[0]; - - ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]); - ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]); - } - ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0]; - ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0]; } - (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0]; + + // div sqrt(d) + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); + ptr_P[idx_SEQ_UNIT] /= sqrt(128.0); + } + + // softmax + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); + ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; + } + + // * V + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + (float4 &)ptr_V[idx_SEQ_UNIT * 4] + = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; + } + else{ + (float4 &)ptr_V[idx_SEQ_UNIT * 4] + = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; + (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] + = (float4 &)ptr_V[idx_SEQ_UNIT * 4]; + } + + #pragma unroll + for (int i = 0; i < 4; i ++) + ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]); + } + + #pragma unroll + for (int i = 0; i < 4; i ++) + ptr_O[i] /= ptr_sum[0]; + + (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; + if(threadIdx.x == 0){ + output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; + } + } +__global__ void _attention_kvcache_kernel_128_2(int* position_id, + float* output_matmul, + AttentionKVCacheMetadata compMeta, + float* output_O_temp, + float* output_sum_temp) { + int lane_id = threadIdx.x % WARP_SIZE; + int group_id = threadIdx.x / WARP_SIZE; + int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; + + float ptr_O[4] = {0}; + float ptr_O_sum[4] = {0}; + float ptr_sum = 0; + float ptr_sum_temp; + int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT; + + #pragma unroll + for(int i = 0; i < size; i ++){ + (float4 &)ptr_O[0] + = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]; + ptr_sum_temp = output_sum_temp[i + parallel_idx * size]; + + #pragma unroll + for(int k = 0; k < 
4; k ++) + ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp; + ptr_sum += ptr_sum_temp; + } + + #pragma unroll + for(int k = 0; k < 4; k ++) + ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum; + + (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0]; + +} + + namespace infini { -void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k, - float *input_v, int *position_id, float *output_matmul, - const AttentionKVCacheMetadata &compMeta) { - IT_ASSERT(compMeta.dimSize[3] == 64); - dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1); +void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, + float *input_q, float *input_k, + float *input_v, int *position_id, float *output_matmul, + const AttentionKVCacheMetadata &compMeta, + float *output_O_temp, float *output_sum_temp) { + IT_ASSERT(compMeta.dimSize[3] == 128); + + int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT; + dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 blockDim(BLOCKSIZE, 1); - _attention_kvcache_kernel<<>>( - input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta); + assert(compMeta.dimSize[3] == 128); + _attention_kvcache_kernel_128_1<<>>( + input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, + compMeta, output_O_temp, output_sum_temp); + _attention_kvcache_kernel_128_2<<>>( + position_id, output_matmul, compMeta, output_O_temp, output_sum_temp); + } } // namespace infini diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc index 3ccf861d..3a9bff45 100644 --- a/test/kernels/cuda/test_cuda_attention.cc +++ b/test/kernels/cuda/test_cuda_attention.cc @@ -14,11 +14,11 @@ TEST(AttentionKVCache, Cuda) { auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); - auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); + auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32); auto op = gCuda->addOp( @@ -32,11 +32,14 @@ TEST(AttentionKVCache, Cuda) { position_id_d->setData(IncrementalGenerator()); cudaRuntime->run(gCuda); - auto oCpu = gCpu->cloneTensor(op->getOutput()); + auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]); EXPECT_TRUE(oCpu->equalData(vector{ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); } } // namespace infini
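Note on the two-stage attention kernel introduced above: _attention_kvcache_kernel_128_1 walks the cached sequence in chunks of SEQ_UNIT positions, one chunk per blockIdx.y, and writes each chunk's locally normalized output and its exp-sum into the workspace buffers output_O_temp and output_sum_temp; _attention_kvcache_kernel_128_2 then re-weights the chunks by their exp-sums and divides by the global sum. The NumPy sketch below is illustrative only and not part of the patch (chunked_attention and seq_unit are made-up names); it shows why this reduction reproduces plain softmax attention for a single query vector, and, like the kernel, it skips the usual max-subtraction stabilization.

import numpy as np

def chunked_attention(q, K, V, seq_unit=32):
    # q: (head_dim,), K and V: (seq_len, head_dim).
    # Computes softmax(q @ K.T / sqrt(head_dim)) @ V the way the two kernels do:
    # per-chunk exponent sums first, then a weighted merge over the chunks.
    d = q.shape[0]
    chunk_out, chunk_sum = [], []
    for start in range(0, K.shape[0], seq_unit):   # one chunk per blockIdx.y in kernel 1
        s = np.exp(q @ K[start:start + seq_unit].T / np.sqrt(d))
        chunk_sum.append(s.sum())                                   # -> output_sum_temp
        chunk_out.append(s @ V[start:start + seq_unit] / s.sum())   # -> output_O_temp
    sums = np.array(chunk_sum)
    # kernel 2: undo the per-chunk normalization and divide by the global sum
    return (np.stack(chunk_out) * sums[:, None]).sum(axis=0) / sums.sum()

# sanity check against the direct formula
rng = np.random.default_rng(0)
q = rng.standard_normal(128)
K = rng.standard_normal((100, 128))
V = rng.standard_normal((100, 128))
w = np.exp(q @ K.T / np.sqrt(128))
assert np.allclose(chunked_attention(q, K, V), (w / w.sum()) @ V)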