diff --git a/include/bang/bang_common.h b/include/bang/bang_common.h index 16a09053..6c8047ef 100644 --- a/include/bang/bang_common.h +++ b/include/bang/bang_common.h @@ -2,6 +2,7 @@ #include "cnnl.h" #include "cnrt.h" #include "core/common.h" +#include "core/data_type.h" #define checkBangError(call) \ { \ @@ -27,4 +28,35 @@ namespace infini { using BangPtr = void *; +inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) { + if (dataType == DataType::Float32) { + return CNNL_DTYPE_FLOAT; + } + if (dataType == DataType::Float16) { + return CNNL_DTYPE_HALF; + } + if (dataType == DataType::Double) { + return CNNL_DTYPE_DOUBLE; + } + if (dataType == DataType::Int8) { + return CNNL_DTYPE_INT8; + } + if (dataType == DataType::Int32) { + return CNNL_DTYPE_INT32; + } + if (dataType == DataType::UInt8) { + return CNNL_DTYPE_UINT8; + } + if (dataType == DataType::BFloat16) { + return CNNL_DTYPE_BFLOAT16; + } + if (dataType == DataType::Int64) { + return CNNL_DTYPE_INT64; + } + if (dataType == DataType::Bool) { + return CNNL_DTYPE_BOOL; + } + return CNNL_DTYPE_INVALID; +} + } // namespace infini diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc index bc970760..41de7534 100644 --- a/src/kernels/bang/activation.cc +++ b/src/kernels/bang/activation.cc @@ -11,7 +11,6 @@ class UnaryCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -22,13 +21,13 @@ class UnaryCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + 
aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlActivationDescriptor_t opDesc; checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); checkCnnlError(cnnlSetActivationDescriptor_v2( @@ -51,7 +50,6 @@ class RoundCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -62,13 +60,13 @@ class RoundCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlRound(context->cnnlHandle(), aDesc, aData, cDesc, cData); if (stat != CNNL_STATUS_SUCCESS) @@ -82,7 +80,6 @@ class PReluCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -95,17 +92,17 @@ class PReluCnnl : public BangKernelWithoutConfig 
{ auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlPrelu(context->cnnlHandle(), aDesc, aData, bDesc, bData, cDesc, cData); @@ -122,7 +119,6 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -185,13 +181,13 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, inDim.size(), - inDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + inDim.size(), inDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, outDim.size(), - outDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, 
cnnlDataTypeConvert(op->getDType()), + outDim.size(), outDim.data())); float alpha = 1.0; float beta = 0.0; cnnlStatus_t stat = diff --git a/src/kernels/bang/activation_backward.cc b/src/kernels/bang/activation_backward.cc index c2c3baa6..0e9d8435 100644 --- a/src/kernels/bang/activation_backward.cc +++ b/src/kernels/bang/activation_backward.cc @@ -10,7 +10,6 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const yData = (op->getInputs(0)->getRawDataPtr()); @@ -25,21 +24,21 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig { auto diffxDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&yDesc)); - checkCnnlError(cnnlSetTensorDescriptor(yDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, yDim.size(), - yDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + yDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + yDim.size(), yDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&diffYDesc)); checkCnnlError(cnnlSetTensorDescriptor( - diffYDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffyDim.size(), - diffyDim.data())); + diffYDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + diffyDim.size(), diffyDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&xDesc)); - checkCnnlError(cnnlSetTensorDescriptor(xDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, xDim.size(), - xDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + xDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + xDim.size(), xDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&diffXDesc)); checkCnnlError(cnnlSetTensorDescriptor( - diffXDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffxDim.size(), - diffxDim.data())); + diffXDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + diffxDim.size(), diffxDim.data())); // get op 
descriptor cnnlActivationDescriptor_t opDesc; checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); diff --git a/src/kernels/bang/all_gather.cc b/src/kernels/bang/all_gather.cc index d44569fe..f35d71d4 100644 --- a/src/kernels/bang/all_gather.cc +++ b/src/kernels/bang/all_gather.cc @@ -19,18 +19,17 @@ class AllGatherCNCL : public BangKernelWithoutConfig { BangPtr output_temp = context->getWorkspace(op->getInputs(0)->getBytes() * world_size); // void *output = op->getOutput()->getRawDataPtr(); - // IT_ASSERT(op->getDType() == DataType::Float32); checkBangError(cnrtMalloc(&output_temp, op->getInputs(0)->getBytes() * world_size)); size_t bytes = op->getInputs(0)->getBytes(); - size_t count = bytes / op->getDType().getSize(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); CNCL_CHECK( - cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue)); + cnclAllGather(input, output_temp, count, cnclUint8, comm, queue)); checkBangError(cnrtQueueSync(queue)); for (int i = 0; i < world_size; ++i) { Tensor output = op->getOutput(i); @@ -42,8 +41,8 @@ class AllGatherCNCL : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32, - AllGatherCNCL, "AllGather_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllGather, AllGatherCNCL, + "AllGather_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/all_reduce.cc b/src/kernels/bang/all_reduce.cc index 4e9266fb..c9e42c65 100644 --- a/src/kernels/bang/all_reduce.cc +++ b/src/kernels/bang/all_reduce.cc @@ -13,14 +13,14 @@ class AllReduceCNCL : public BangKernelWithoutConfig { auto context = dynamic_cast(_context); void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); - IT_ASSERT(op->getDType() == DataType::Float32); - size_t count = op->getInputs(0)->size(); + size_t bytes = 
op->getInputs(0)->getBytes(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); // checkBangError(cnrtQueueSync(queue)); - CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(), + CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(), comm, queue)); checkBangError(cnrtQueueSync(queue)); } @@ -41,13 +41,13 @@ class AllReduceMaxCNCL : public AllReduceCNCL { cnclReduceOp_t getRedOp() const override { return cnclMax; } }; -REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32, - AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32, - AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32, - AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32"); -REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32, - AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, AllReduceSumCNCL, + "AllReduce_Sum_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, AllReduceProdCNCL, + "AllReduce_Prod_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, AllReduceMinCNCL, + "AllReduce_Min_CNCL_BANG"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, AllReduceMaxCNCL, + "AllReduce_Max_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc index 31aba547..633f0d88 100644 --- a/src/kernels/bang/batchnorm.cc +++ b/src/kernels/bang/batchnorm.cc @@ -7,7 +7,6 @@ class BatchNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const input = 
(op->getInputs(0)->getRawDataPtr()); @@ -33,18 +32,18 @@ class BatchNormCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, dims.size(), - dims.data())); - checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, dims.size(), - dimsTrans)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, outDims.size(), - outDims.data())); - checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, outDims.size(), - dimsOutTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dims.size(), dims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + intransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + dims.size(), dimsTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + outDims.size(), outDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outtransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + outDims.size(), dimsOutTrans)); cnnlTransposeDescriptor_t opDesc; checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute)); @@ -53,9 +52,9 @@ class BatchNormCnnl : public BangKernelWithoutConfig { &wsSize); BangPtr wsData = context->getWorkspace(wsSize); BangPtr inputTrans = context->getWorkspace( - cnnlGetTensorElementNum(inDesc) * sizeof(float)); + cnnlGetTensorElementNum(inDesc) * op->getDType().getSize()); BangPtr outputTrans = context->getWorkspace( - cnnlGetTensorElementNum(inDesc) * sizeof(float)); + cnnlGetTensorElementNum(inDesc) * op->getDType().getSize()); cnnlStatus_t stat = 
cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input, intransDesc, inputTrans, wsData, wsSize); @@ -67,7 +66,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t paraDesc; checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc)); checkCnnlError(cnnlSetTensorDescriptor( - paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + paraDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data())); float alpha = 1.f, beta = 0.f; diff --git a/src/kernels/bang/broadcast.cc b/src/kernels/bang/broadcast.cc index 411506c5..e0bc5879 100644 --- a/src/kernels/bang/broadcast.cc +++ b/src/kernels/bang/broadcast.cc @@ -13,22 +13,22 @@ class BroadcastCNCL : public BangKernelWithoutConfig { auto context = dynamic_cast(_context); void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); - IT_ASSERT(op->getDType() == DataType::Float32); - size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / sizeof(uint8_t); cnclComm_t comm = dynamic_cast(context->getCommunicator()) .getCnclComm(); cnrtQueue_t queue = context->getBangQueue(); // TODO: Using default stream 0 for now. 
- CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32, - op->getRoot(), comm, queue)); + CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(), + comm, queue)); checkBangError(cnrtQueueSync(queue)); } }; -REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32, - BroadcastCNCL, "Broadcast_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Broadcast, BroadcastCNCL, + "Broadcast_CNCL_BANG"); } // namespace infini #endif diff --git a/src/kernels/bang/ceil.cc b/src/kernels/bang/ceil.cc index ce77741d..61c1c98b 100644 --- a/src/kernels/bang/ceil.cc +++ b/src/kernels/bang/ceil.cc @@ -7,7 +7,6 @@ class CeilCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class CeilCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlCeil(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/clip.cc b/src/kernels/bang/clip.cc index fdb682f0..074583af 100644 --- a/src/kernels/bang/clip.cc +++ b/src/kernels/bang/clip.cc @@ -7,7 +7,6 @@ class ClipCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const 
RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -19,9 +18,9 @@ class ClipCnnl : public BangKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); cnnlStatus_t stat = cnnlClip(context->cnnlHandle(), aDesc, aData, &min, &max, cData); if (stat != CNNL_STATUS_SUCCESS) diff --git a/src/kernels/bang/concat.cc b/src/kernels/bang/concat.cc index dae092c5..2350842c 100644 --- a/src/kernels/bang/concat.cc +++ b/src/kernels/bang/concat.cc @@ -7,7 +7,6 @@ class ConcatCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numInputs(); int axis = op->getDim(); @@ -15,17 +14,18 @@ class ConcatCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); cnnlTensorDescriptor_t desc; checkCnnlError(cnnlCreateTensorDescriptor(&desc)); - checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlTensorDescriptor_t descArray[num]; for (int i = 0; i < num; ++i) { checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); - checkCnnlError(cnnlSetTensorDescriptor( - descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, - op->getInputs(i)->getDims().size(), - op->getInputs(i)->getDims().data())); + checkCnnlError( + 
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW, + cnnlDataTypeConvert(op->getDType()), + op->getInputs(i)->getDims().size(), + op->getInputs(i)->getDims().data())); } void *argv[num]; diff --git a/src/kernels/bang/conv.cc b/src/kernels/bang/conv.cc index 24d8a3fd..655a2fa7 100644 --- a/src/kernels/bang/conv.cc +++ b/src/kernels/bang/conv.cc @@ -7,7 +7,6 @@ class ConvCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -55,20 +55,24 @@ class ConvCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aInDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aInDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inputs0)); + checkCnnlError(cnnlSetTensorDescriptor( + aInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs0)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs0Array)); + aDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + inputs0Array)); checkCnnlError(cnnlCreateTensorDescriptor(&bInDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bInDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inputs1)); + checkCnnlError(cnnlSetTensorDescriptor( + 
bInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs1)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs1Array)); + bDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + inputs1Array)); int permute[4] = {0, 2, 3, 1}; cnnlTransposeDescriptor_t opDesc; @@ -80,7 +84,7 @@ class ConvCnnl : public BangKernelWithoutConfig { &wsSize); BangPtr wsData = context->getWorkspace(wsSize); BangPtr aDataOut = context->getWorkspace( - cnnlGetTensorElementNum(aInDesc) * sizeof(float)); + cnnlGetTensorElementNum(aInDesc) * op->getDType().getSize()); cnnlStatus_t stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, aInDesc, aData, aDesc, aDataOut, wsData, wsSize); @@ -91,7 +95,7 @@ class ConvCnnl : public BangKernelWithoutConfig { &wsSize); wsData = context->getWorkspace(wsSize); BangPtr bDataOut = context->getWorkspace( - cnnlGetTensorElementNum(bInDesc) * sizeof(float)); + cnnlGetTensorElementNum(bInDesc) * op->getDType().getSize()); stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, bInDesc, bData, bDesc, bDataOut, wsData, wsSize); if (stat != CNNL_STATUS_SUCCESS) @@ -100,11 +104,13 @@ class ConvCnnl : public BangKernelWithoutConfig { // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cInDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cInDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, outputArray)); + cInDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4, + outputArray)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, output)); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + output)); cnnlConvolutionForwardAlgo_t algo; cnnlGetConvolutionForwardAlgorithm(context->cnnlHandle(), convDesc, @@ -116,7 +122,7 @@ class ConvCnnl : public BangKernelWithoutConfig { algo, &wsSize); 
wsData = context->getWorkspace(wsSize); BangPtr cDataIn = context->getWorkspace( - cnnlGetTensorElementNum(cInDesc) * sizeof(float)); + cnnlGetTensorElementNum(cInDesc) * op->getDType().getSize()); stat = cnnlConvolutionForward( context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc, diff --git a/src/kernels/bang/conv_trans.cc b/src/kernels/bang/conv_trans.cc index ce93fc9a..b2fad8ec 100644 --- a/src/kernels/bang/conv_trans.cc +++ b/src/kernels/bang/conv_trans.cc @@ -7,7 +7,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvTransCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -43,14 +43,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimInputs0.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data())); + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimInputs1.data())); // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); 
checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data())); + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + dimOutput.data())); cnnlConvolutionBwdDataAlgo_t algo; cnnlGetConvolutionBackwardDataAlgorithm( diff --git a/src/kernels/bang/convbpfilter.cc b/src/kernels/bang/convbpfilter.cc index f3e9ec94..7d2930d8 100644 --- a/src/kernels/bang/convbpfilter.cc +++ b/src/kernels/bang/convbpfilter.cc @@ -7,7 +7,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); @@ -21,8 +20,9 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { cnnlConvolutionDescriptor_t convDesc; checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); - checkCnnlError(cnnlSetConvolutionDescriptor( - convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); + checkCnnlError( + cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g, + cnnlDataTypeConvert(op->getDType()))); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); @@ -63,15 +63,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { // get inputs checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs0Array)); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs0Array)); checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(aDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - inputs0ArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + aDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, inputs0ArrayTrans)); size_t wsTrans1Size = 
dimInputs0[0] * dimInputs0[1] * dimInputs0[2] * - dimInputs0[3] * sizeof(float); + dimInputs0[3] * op->getDType().getSize(); BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size); cnnlStatus_t stat = @@ -82,15 +83,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlSetTensorDescriptor( - bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs1Array)); + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inputs1Array)); checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(bDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - inputs1ArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + bDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, inputs1ArrayTrans)); size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] * - dimInputs1[3] * sizeof(float); + dimInputs1[3] * op->getDType().getSize(); BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size); stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData, @@ -101,15 +103,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig { // get outputs checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, outputArray)); + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + outputArray)); checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans)); - checkCnnlError(cnnlSetTensorDescriptor(cDescTrans, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, 4, - outputArrayTrans)); + checkCnnlError(cnnlSetTensorDescriptor( + cDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), + 4, outputArrayTrans)); size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] * - dimOutput[3] * sizeof(float); + dimOutput[3] * op->getDType().getSize(); BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size); cnnlConvolutionBwdFilterAlgo_t 
algo; diff --git a/src/kernels/bang/det.cc b/src/kernels/bang/det.cc index eeb197b6..03f25041 100644 --- a/src/kernels/bang/det.cc +++ b/src/kernels/bang/det.cc @@ -7,7 +7,6 @@ class DetCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -24,14 +23,14 @@ class DetCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimin.size(), - dimin.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimin.size(), dimin.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimout.size(), - dimout.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimout.size(), dimout.data())); cnnlStatus_t stat = cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/element_wise.cc b/src/kernels/bang/element_wise.cc index e919e7d1..da18a6d4 100644 --- a/src/kernels/bang/element_wise.cc +++ b/src/kernels/bang/element_wise.cc @@ -11,7 +11,6 @@ class ElementWiseCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -31,24 +30,25 @@ class ElementWiseCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, 
- CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); cnnlOpTensorDescriptor_t opDesc; checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc)); checkCnnlError(cnnlSetOpTensorDescriptor( - opDesc, getOpType(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN)); + opDesc, getOpType(), cnnlDataTypeConvert(op->getDType()), + CNNL_NOT_PROPAGATE_NAN)); size_t wsSize; cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -75,7 +75,6 @@ class LogicOpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -95,17 +94,17 @@ class LogicOpCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - 
CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetLogicOpWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -129,7 +128,6 @@ class BitComputeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -182,7 +180,6 @@ class DivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -202,17 +199,17 @@ class DivCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - 
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -235,7 +232,6 @@ class MaximumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -255,17 +251,17 @@ class MaximumCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); @@ -287,7 +283,6 @@ class MinimumCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); 
void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -307,17 +302,17 @@ class MinimumCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); @@ -339,7 +334,6 @@ class MSELossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -359,18 +353,18 @@ class MSELossCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - 
b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); cnnlStatus_t stat; if (reduction == MSELossObj::None) { stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc, @@ -396,7 +390,6 @@ class PowerCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -417,17 +410,17 @@ class PowerCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -450,7 
+443,6 @@ class FloorDivCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -470,17 +462,17 @@ class FloorDivCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetFloorDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -503,7 +495,6 @@ class FloorModCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -523,17 +514,17 @@ class FloorModCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + 
checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetFloorModWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, @@ -556,7 +547,6 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -576,17 +566,17 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, a_dim.size(), - a_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + a_dim.size(), a_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, b_dim.size(), - b_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + b_dim.size(), b_dim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - 
CNNL_DTYPE_FLOAT, c_dim.size(), - c_dim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + c_dim.size(), c_dim.data())); size_t wsSize; cnnlGetSquaredDifferenceWorkspaceSize(context->cnnlHandle(), aDesc, diff --git a/src/kernels/bang/erf.cc b/src/kernels/bang/erf.cc index dcf8eacd..4d4f5b73 100644 --- a/src/kernels/bang/erf.cc +++ b/src/kernels/bang/erf.cc @@ -7,7 +7,6 @@ class ErfCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class ErfCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/exp.cc b/src/kernels/bang/exp.cc index 4b3d88ab..fbcd9485 100644 --- a/src/kernels/bang/exp.cc +++ b/src/kernels/bang/exp.cc @@ -7,7 +7,6 @@ class ExpCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = 
(op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class ExpCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/fill.cc b/src/kernels/bang/fill.cc index c2de64d5..b2eebbb7 100644 --- a/src/kernels/bang/fill.cc +++ b/src/kernels/bang/fill.cc @@ -7,7 +7,6 @@ class FillCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -17,9 +16,9 @@ class FillCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlFill(context->cnnlHandle(), value, cDesc, cData); diff --git a/src/kernels/bang/floor.cc b/src/kernels/bang/floor.cc index 83d8b505..d1c4513c 100644 --- a/src/kernels/bang/floor.cc +++ b/src/kernels/bang/floor.cc @@ -7,7 +7,6 @@ class FloorCnnl : public 
BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class FloorCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlFloor(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc index 97fa395c..63c0a872 100644 --- a/src/kernels/bang/gather.cc +++ b/src/kernels/bang/gather.cc @@ -7,7 +7,6 @@ class GatherCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,9 +19,9 @@ class GatherCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); 
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError( cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST)); @@ -30,9 +29,9 @@ class GatherCnnl : public BangKernelWithoutConfig { CNNL_DTYPE_INT32, bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); BangPtr wsData = context->getWorkspace(aDim.size() * 4); context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4); diff --git a/src/kernels/bang/hardtanh.cc b/src/kernels/bang/hardtanh.cc index 1f91084e..63f7d36c 100644 --- a/src/kernels/bang/hardtanh.cc +++ b/src/kernels/bang/hardtanh.cc @@ -7,7 +7,6 @@ class HardtanhCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,7 +19,8 @@ class HardtanhCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlStatus_t stat = cnnlHardtanh(context->cnnlHandle(), aDesc, aData, max, min, aDesc, cData); diff --git a/src/kernels/bang/l2loss.cc b/src/kernels/bang/l2loss.cc index deb127be..659fde6b 100644 --- a/src/kernels/bang/l2loss.cc +++ b/src/kernels/bang/l2loss.cc @@ -7,7 +7,6 @@ class L2LossCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = 
dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,7 +17,8 @@ class L2LossCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlStatus_t stat = cnnlL2Loss(context->cnnlHandle(), aDesc, aData, cData); diff --git a/src/kernels/bang/layer_norm.cc b/src/kernels/bang/layer_norm.cc index acd36624..ed33bcf6 100644 --- a/src/kernels/bang/layer_norm.cc +++ b/src/kernels/bang/layer_norm.cc @@ -8,7 +8,6 @@ class LayerNormCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inputData = (op->getInputs(0)->getRawDataPtr()); @@ -29,17 +28,17 @@ class LayerNormCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, inDims.size(), - inDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + inDims.size(), inDims.data())); checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc)); checkCnnlError(cnnlSetTensorDescriptor( - fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(), - fiterDims.data())); + fiterDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + fiterDims.size(), fiterDims.data())); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, outDims.size(), - outDims.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_ARRAY, 
cnnlDataTypeConvert(op->getDType()), + outDims.size(), outDims.data())); size_t wsSize; cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc, &wsSize); diff --git a/src/kernels/bang/log.cc b/src/kernels/bang/log.cc index c2a3e566..74fcab72 100644 --- a/src/kernels/bang/log.cc +++ b/src/kernels/bang/log.cc @@ -7,7 +7,6 @@ class LogCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -33,13 +32,13 @@ class LogCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/lrn.cc b/src/kernels/bang/lrn.cc index 14bca5fb..254b59ae 100644 --- a/src/kernels/bang/lrn.cc +++ b/src/kernels/bang/lrn.cc @@ -7,7 +7,6 @@ class LRNCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,13 +19,13 @@ class LRNCnnl : public BangKernelWithoutConfig { auto size = op->getSize(); 
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); size_t extra_size; cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc, diff --git a/src/kernels/bang/matmul.cc b/src/kernels/bang/matmul.cc index 09780067..9afb2377 100644 --- a/src/kernels/bang/matmul.cc +++ b/src/kernels/bang/matmul.cc @@ -8,7 +8,6 @@ class MatmulCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); auto input_num = op->numInputs(); @@ -38,25 +37,26 @@ class MatmulCnnl : public BangKernelWithoutConfig { int32_t transB = op->getTransB(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimInputs0.size(), dimInputs0.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimInputs0.size(), dimInputs0.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimInputs1.size(), dimInputs1.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimInputs1.size(), dimInputs1.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError( - 
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - dimOutput.size(), dimOutput.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimOutput.size(), dimOutput.data())); if (input_num > 2) { checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc)); - checkCnnlError(cnnlSetTensorDescriptor( - biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(), - dimBias.data())); + checkCnnlError( + cnnlSetTensorDescriptor(biasDesc, CNNL_LAYOUT_ARRAY, + cnnlDataTypeConvert(op->getDType()), + dimBias.size(), dimBias.data())); } cnnlMatMulDescriptor_t bmm_desc; diff --git a/src/kernels/bang/negtensor.cc b/src/kernels/bang/negtensor.cc index 12377610..170138ac 100644 --- a/src/kernels/bang/negtensor.cc +++ b/src/kernels/bang/negtensor.cc @@ -7,7 +7,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class NegTensorCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlNegTensor(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/pad.cc 
b/src/kernels/bang/pad.cc index e8aafa1a..c35b92e1 100644 --- a/src/kernels/bang/pad.cc +++ b/src/kernels/bang/pad.cc @@ -7,7 +7,6 @@ class PadCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -37,14 +36,14 @@ class PadCnnl : public BangKernelWithoutConfig { float paddingValue = 0.0; // input checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimIn.size(), - dimIn.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimIn.size(), dimIn.data())); // output checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimOut.size(), - dimOut.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimOut.size(), dimOut.data())); cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData, paddings, &paddingValue, cDesc, cData); diff --git a/src/kernels/bang/pooling.cc b/src/kernels/bang/pooling.cc index 90a0637f..95db2185 100644 --- a/src/kernels/bang/pooling.cc +++ b/src/kernels/bang/pooling.cc @@ -8,7 +8,6 @@ class PoolingCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); @@ -20,8 +19,9 @@ class PoolingCnnl : public BangKernelWithoutConfig { int inArray[4] = {n, c, h, w}; cnnlTensorDescriptor_t inDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); 
- checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, inArray)); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + inArray)); bool mode = op->getCeilMode(); // get maxpool descriptor @@ -37,8 +37,9 @@ class PoolingCnnl : public BangKernelWithoutConfig { int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]}; cnnlTensorDescriptor_t outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, 4, outArray)); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4, + outArray)); size_t wsSize; cnnlGetPoolingWorkspaceSize(context->cnnlHandle(), getPoolingMode(), outVec[3], outVec[2], &wsSize); diff --git a/src/kernels/bang/reciprocal.cc b/src/kernels/bang/reciprocal.cc index 7b61c2ca..1c95393a 100644 --- a/src/kernels/bang/reciprocal.cc +++ b/src/kernels/bang/reciprocal.cc @@ -7,7 +7,6 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class ReciprocalCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, 
CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlReciprocal(context->cnnlHandle(), aDesc, aData, cDesc, cData); diff --git a/src/kernels/bang/reduce.cc b/src/kernels/bang/reduce.cc index 810aca72..374a0d0a 100644 --- a/src/kernels/bang/reduce.cc +++ b/src/kernels/bang/reduce.cc @@ -9,7 +9,6 @@ class ReduceCnnlBase : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -26,20 +25,20 @@ class ReduceCnnlBase : public BangKernelWithoutConfig { cnnlTensorDescriptor_t inDesc, outDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); - checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); - checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); // get reduce descriptor cnnlReduceDescriptor_t reduceDesc; checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc)); checkCnnlError(cnnlSetReduceDescriptor_v2( reduceDesc, axes.data(), axes.size(), getReduceOp(), - CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, - CNNL_32BIT_INDICES, 0.0)); + cnnlDataTypeConvert(op->getDType()), CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0)); // get workspace size_t workspaceSize = 0; diff --git a/src/kernels/bang/rsqrt.cc b/src/kernels/bang/rsqrt.cc index 
66e63e0a..f6ba8e12 100644 --- a/src/kernels/bang/rsqrt.cc +++ b/src/kernels/bang/rsqrt.cc @@ -7,7 +7,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class RsqrtCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/slice.cc b/src/kernels/bang/slice.cc index 5cc772aa..303ce741 100644 --- a/src/kernels/bang/slice.cc +++ b/src/kernels/bang/slice.cc @@ -42,11 +42,13 @@ class SliceCnnl : public BangKernelWithoutConfig { // input checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array)); + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim_size, aDim_array)); // output checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor( - cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array)); + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + cDim_size, cDim_array)); cnnlStatus_t stat = 
cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array, @@ -59,6 +61,6 @@ class SliceCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, +REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl, "Slice_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/split.cc b/src/kernels/bang/split.cc index 397b5063..d7342946 100644 --- a/src/kernels/bang/split.cc +++ b/src/kernels/bang/split.cc @@ -7,7 +7,6 @@ class SplitCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); int num = op->numOutputs(); int axis = op->getDim(); @@ -16,15 +15,17 @@ class SplitCnnl : public BangKernelWithoutConfig { cnnlTensorDescriptor_t desc; checkCnnlError(cnnlCreateTensorDescriptor(&desc)); checkCnnlError(cnnlSetTensorDescriptor( - desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); + desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + dim.size(), dim.data())); cnnlTensorDescriptor_t descArray[num]; for (int i = 0; i < num; ++i) { checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); - checkCnnlError(cnnlSetTensorDescriptor( - descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, - op->getOutput(i)->getDims().size(), - op->getOutput(i)->getDims().data())); + checkCnnlError( + cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW, + cnnlDataTypeConvert(op->getDType()), + op->getOutput(i)->getDims().size(), + op->getOutput(i)->getDims().data())); } void *const inputData = (op->getInputs(0)->getRawDataPtr()); diff --git a/src/kernels/bang/sqrt.cc b/src/kernels/bang/sqrt.cc index a1ed85c9..59bc8dcb 100644 --- a/src/kernels/bang/sqrt.cc +++ b/src/kernels/bang/sqrt.cc @@ -7,7 +7,6 @@ class SqrtCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const 
override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class SqrtCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlStatus_t stat = cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, diff --git a/src/kernels/bang/transpose.cc b/src/kernels/bang/transpose.cc index 7dedd21d..c69fa455 100644 --- a/src/kernels/bang/transpose.cc +++ b/src/kernels/bang/transpose.cc @@ -7,7 +7,6 @@ class TransposeCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -18,13 +17,13 @@ class TransposeCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimin.size(), - dimin.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimin.size(), dimin.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - 
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dimout.size(), - dimout.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dimout.size(), dimout.data())); auto permute = op->getPermute(); cnnlTransposeDescriptor_t opDesc; @@ -53,7 +52,6 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -73,13 +71,13 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig { auto dimout = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, reshape.size(), - reshape.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + reshape.size(), reshape.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError( - cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, - transpose.size(), transpose.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + transpose.size(), transpose.data())); cnnlTransposeDescriptor_t opDesc; checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); diff --git a/src/kernels/bang/trigon.cc b/src/kernels/bang/trigon.cc index 989858c4..d6ce8273 100644 --- a/src/kernels/bang/trigon.cc +++ b/src/kernels/bang/trigon.cc @@ -9,7 +9,6 @@ class TrigonCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -20,13 +19,13 
@@ class TrigonCnnl : public BangKernelWithoutConfig { auto cDim = op->getOutput()->getDims(); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); - checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, - CNNL_DTYPE_FLOAT, cDim.size(), - cDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), + cDim.size(), cDim.data())); cnnlTrigonDescriptor_t opDesc; checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc)); diff --git a/src/kernels/bang/where.cc b/src/kernels/bang/where.cc index 8786f3fd..4ed26bd7 100644 --- a/src/kernels/bang/where.cc +++ b/src/kernels/bang/where.cc @@ -7,7 +7,6 @@ class WhereCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); @@ -35,21 +34,21 @@ class WhereCnnl : public BangKernelWithoutConfig { } checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, aDim.size(), - aDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); - checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, bDim.size(), - bDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + bDim.size(), bDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); 
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_BOOL, cDim.size(), cDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&dDesc)); - checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dDim.size(), - dDim.data())); + checkCnnlError(cnnlSetTensorDescriptor( + dDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), + dDim.size(), dDim.data())); size_t wsSize; cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc, &wsSize);