support more data type on mlu (#211)

* support more data type

* clang format

* fix little bug

* fix cncl datatype

* fix format

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: Zhang Bolun <Chamberlain0w0@gmail.com>
This commit is contained in:
Hardy 2024-01-24 13:33:33 +08:00 committed by GitHub
parent 51086d2b8d
commit 09b2ecf98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 418 additions and 408 deletions

View File

@ -2,6 +2,7 @@
#include "cnnl.h"
#include "cnrt.h"
#include "core/common.h"
#include "core/data_type.h"
#define checkBangError(call) \
{ \
@ -27,4 +28,35 @@ namespace infini {
using BangPtr = void *;
inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) {
if (dataType == DataType::Float32) {
return CNNL_DTYPE_FLOAT;
}
if (dataType == DataType::Float16) {
return CNNL_DTYPE_HALF;
}
if (dataType == DataType::Double) {
return CNNL_DTYPE_DOUBLE;
}
if (dataType == DataType::Int8) {
return CNNL_DTYPE_INT8;
}
if (dataType == DataType::Int32) {
return CNNL_DTYPE_INT32;
}
if (dataType == DataType::UInt8) {
return CNNL_DTYPE_UINT8;
}
if (dataType == DataType::BFloat16) {
return CNNL_DTYPE_BFLOAT16;
}
if (dataType == DataType::Int64) {
return CNNL_DTYPE_INT64;
}
if (dataType == DataType::Bool) {
return CNNL_DTYPE_BOOL;
}
return CNNL_DTYPE_INVALID;
}
} // namespace infini

View File

@ -11,7 +11,6 @@ class UnaryCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -22,13 +21,13 @@ class UnaryCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlActivationDescriptor_t opDesc;
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
checkCnnlError(cnnlSetActivationDescriptor_v2(
@ -51,7 +50,6 @@ class RoundCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -62,13 +60,13 @@ class RoundCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlRound(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
@ -82,7 +80,6 @@ class PReluCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PReluObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -95,17 +92,17 @@ class PReluCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, bDim.size(),
bDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
bDim.size(), bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlPrelu(context->cnnlHandle(), aDesc, aData,
bDesc, bData, cDesc, cData);
@ -122,7 +119,6 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SoftmaxObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -185,13 +181,13 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, inDim.size(),
inDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
inDim.size(), inDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, outDim.size(),
outDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outDim.size(), outDim.data()));
float alpha = 1.0;
float beta = 0.0;
cnnlStatus_t stat =

View File

@ -10,7 +10,6 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ActivationBackwardObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -25,21 +24,21 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
auto diffxDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&yDesc));
checkCnnlError(cnnlSetTensorDescriptor(yDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, yDim.size(),
yDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
yDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
yDim.size(), yDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&diffYDesc));
checkCnnlError(cnnlSetTensorDescriptor(
diffYDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffyDim.size(),
diffyDim.data()));
diffYDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
diffyDim.size(), diffyDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&xDesc));
checkCnnlError(cnnlSetTensorDescriptor(xDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, xDim.size(),
xDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
xDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
xDim.size(), xDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&diffXDesc));
checkCnnlError(cnnlSetTensorDescriptor(
diffXDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffxDim.size(),
diffxDim.data()));
diffXDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
diffxDim.size(), diffxDim.data()));
// get op descriptor
cnnlActivationDescriptor_t opDesc;
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));

View File

@ -19,18 +19,17 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
BangPtr output_temp =
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
// void *output = op->getOutput()->getRawDataPtr<void *>();
// IT_ASSERT(op->getDType() == DataType::Float32);
checkBangError(cnrtMalloc(&output_temp,
op->getInputs(0)->getBytes() * world_size));
size_t bytes = op->getInputs(0)->getBytes();
size_t count = bytes / op->getDType().getSize();
size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
CNCL_CHECK(
cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue));
cnclAllGather(input, output_temp, count, cnclUint8, comm, queue));
checkBangError(cnrtQueueSync(queue));
for (int i = 0; i < world_size; ++i) {
Tensor output = op->getOutput(i);
@ -42,8 +41,8 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32,
AllGatherCNCL, "AllGather_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllGather, AllGatherCNCL,
"AllGather_CNCL_BANG");
} // namespace infini
#endif

View File

@ -13,14 +13,14 @@ class AllReduceCNCL : public BangKernelWithoutConfig {
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t count = op->getInputs(0)->size();
size_t bytes = op->getInputs(0)->getBytes();
size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
// checkBangError(cnrtQueueSync(queue));
CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(),
CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(),
comm, queue));
checkBangError(cnrtQueueSync(queue));
}
@ -41,13 +41,13 @@ class AllReduceMaxCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclMax; }
};
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32,
AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32,
AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32,
AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32,
AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, AllReduceSumCNCL,
"AllReduce_Sum_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, AllReduceProdCNCL,
"AllReduce_Prod_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, AllReduceMinCNCL,
"AllReduce_Min_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, AllReduceMaxCNCL,
"AllReduce_Max_CNCL_BANG");
} // namespace infini
#endif

View File

@ -7,7 +7,6 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<BatchNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
@ -33,18 +32,18 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, dims.size(),
dims.data()));
checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, dims.size(),
dimsTrans));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, outDims.size(),
outDims.data()));
checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, outDims.size(),
dimsOutTrans));
checkCnnlError(cnnlSetTensorDescriptor(
inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dims.size(), dims.data()));
checkCnnlError(cnnlSetTensorDescriptor(
intransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
dims.size(), dimsTrans));
checkCnnlError(cnnlSetTensorDescriptor(
outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
outDims.size(), outDims.data()));
checkCnnlError(cnnlSetTensorDescriptor(
outtransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
outDims.size(), dimsOutTrans));
cnnlTransposeDescriptor_t opDesc;
checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));
checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute));
@ -53,9 +52,9 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
&wsSize);
BangPtr wsData = context->getWorkspace(wsSize);
BangPtr inputTrans = context->getWorkspace(
cnnlGetTensorElementNum(inDesc) * sizeof(float));
cnnlGetTensorElementNum(inDesc) * op->getDType().getSize());
BangPtr outputTrans = context->getWorkspace(
cnnlGetTensorElementNum(inDesc) * sizeof(float));
cnnlGetTensorElementNum(inDesc) * op->getDType().getSize());
cnnlStatus_t stat =
cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input,
intransDesc, inputTrans, wsData, wsSize);
@ -67,7 +66,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t paraDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc));
checkCnnlError(cnnlSetTensorDescriptor(
paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
paraDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data()));
float alpha = 1.f, beta = 0.f;

View File

@ -13,22 +13,22 @@ class BroadcastCNCL : public BangKernelWithoutConfig {
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
size_t bytes = op->getInputs(0)->getBytes();
size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
// TODO: Using default stream 0 for now.
CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32,
op->getRoot(), comm, queue));
CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(),
comm, queue));
checkBangError(cnrtQueueSync(queue));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32,
BroadcastCNCL, "Broadcast_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, BroadcastCNCL,
"Broadcast_CNCL_BANG");
} // namespace infini
#endif

View File

@ -7,7 +7,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class CeilCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlCeil(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ClipObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -19,9 +18,9 @@ class ClipCnnl : public BangKernelWithoutConfig {
auto aDim = op->getInputs(0)->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
cnnlStatus_t stat =
cnnlClip(context->cnnlHandle(), aDesc, aData, &min, &max, cData);
if (stat != CNNL_STATUS_SUCCESS)

View File

@ -7,7 +7,6 @@ class ConcatCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConcatObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numInputs();
int axis = op->getDim();
@ -15,17 +14,18 @@ class ConcatCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
cnnlTensorDescriptor_t desc;
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlTensorDescriptor_t descArray[num];
for (int i = 0; i < num; ++i) {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(cnnlSetTensorDescriptor(
descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT,
op->getInputs(i)->getDims().size(),
op->getInputs(i)->getDims().data()));
checkCnnlError(
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
cnnlDataTypeConvert(op->getDType()),
op->getInputs(i)->getDims().size(),
op->getInputs(i)->getDims().data()));
}
void *argv[num];

View File

@ -7,7 +7,6 @@ class ConvCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT));
checkCnnlError(
cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -55,20 +55,24 @@ class ConvCnnl : public BangKernelWithoutConfig {
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aInDesc));
checkCnnlError(cnnlSetTensorDescriptor(aInDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, inputs0));
checkCnnlError(cnnlSetTensorDescriptor(
aInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs0));
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs0Array));
aDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
inputs0Array));
checkCnnlError(cnnlCreateTensorDescriptor(&bInDesc));
checkCnnlError(cnnlSetTensorDescriptor(bInDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, inputs1));
checkCnnlError(cnnlSetTensorDescriptor(
bInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs1));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs1Array));
bDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
inputs1Array));
int permute[4] = {0, 2, 3, 1};
cnnlTransposeDescriptor_t opDesc;
@ -80,7 +84,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
&wsSize);
BangPtr wsData = context->getWorkspace(wsSize);
BangPtr aDataOut = context->getWorkspace(
cnnlGetTensorElementNum(aInDesc) * sizeof(float));
cnnlGetTensorElementNum(aInDesc) * op->getDType().getSize());
cnnlStatus_t stat =
cnnlTranspose_v2(context->cnnlHandle(), opDesc, aInDesc, aData,
aDesc, aDataOut, wsData, wsSize);
@ -91,7 +95,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
&wsSize);
wsData = context->getWorkspace(wsSize);
BangPtr bDataOut = context->getWorkspace(
cnnlGetTensorElementNum(bInDesc) * sizeof(float));
cnnlGetTensorElementNum(bInDesc) * op->getDType().getSize());
stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, bInDesc, bData,
bDesc, bDataOut, wsData, wsSize);
if (stat != CNNL_STATUS_SUCCESS)
@ -100,11 +104,13 @@ class ConvCnnl : public BangKernelWithoutConfig {
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cInDesc));
checkCnnlError(cnnlSetTensorDescriptor(
cInDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, outputArray));
cInDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
outputArray));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, output));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
output));
cnnlConvolutionForwardAlgo_t algo;
cnnlGetConvolutionForwardAlgorithm(context->cnnlHandle(), convDesc,
@ -116,7 +122,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
algo, &wsSize);
wsData = context->getWorkspace(wsSize);
BangPtr cDataIn = context->getWorkspace(
cnnlGetTensorElementNum(cInDesc) * sizeof(float));
cnnlGetTensorElementNum(cInDesc) * op->getDType().getSize());
stat = cnnlConvolutionForward(
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,

View File

@ -7,7 +7,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT));
checkCnnlError(
cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -43,14 +43,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data()));
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimInputs0.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data()));
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimInputs1.data()));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data()));
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimOutput.data()));
cnnlConvolutionBwdDataAlgo_t algo;
cnnlGetConvolutionBackwardDataAlgorithm(

View File

@ -7,7 +7,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvBackwardFilterObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT));
checkCnnlError(
cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -63,15 +63,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs0Array));
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs0Array));
checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(aDescTrans, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, 4,
inputs0ArrayTrans));
checkCnnlError(cnnlSetTensorDescriptor(
aDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
4, inputs0ArrayTrans));
size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] *
dimInputs0[3] * sizeof(float);
dimInputs0[3] * op->getDType().getSize();
BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size);
cnnlStatus_t stat =
@ -82,15 +83,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs1Array));
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs1Array));
checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(bDescTrans, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, 4,
inputs1ArrayTrans));
checkCnnlError(cnnlSetTensorDescriptor(
bDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
4, inputs1ArrayTrans));
size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] *
dimInputs1[3] * sizeof(float);
dimInputs1[3] * op->getDType().getSize();
BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size);
stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData,
@ -101,15 +103,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, outputArray));
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
outputArray));
checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(cDescTrans, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, 4,
outputArrayTrans));
checkCnnlError(cnnlSetTensorDescriptor(
cDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
4, outputArrayTrans));
size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] *
dimOutput[3] * sizeof(float);
dimOutput[3] * op->getDType().getSize();
BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size);
cnnlConvolutionBwdFilterAlgo_t algo;

View File

@ -7,7 +7,6 @@ class DetCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DetObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -24,14 +23,14 @@ class DetCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimin.size(),
dimin.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimin.size(), dimin.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimout.size(),
dimout.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimout.size(), dimout.data()));
cnnlStatus_t stat =
cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);

View File

@ -11,7 +11,6 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -31,24 +30,25 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
cnnlOpTensorDescriptor_t opDesc;
checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc));
checkCnnlError(cnnlSetOpTensorDescriptor(
opDesc, getOpType(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN));
opDesc, getOpType(), cnnlDataTypeConvert(op->getDType()),
CNNL_NOT_PROPAGATE_NAN));
size_t wsSize;
cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -75,7 +75,6 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -95,17 +94,17 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetLogicOpWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -129,7 +128,6 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -182,7 +180,6 @@ class DivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -202,17 +199,17 @@ class DivCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -235,7 +232,6 @@ class MaximumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -255,17 +251,17 @@ class MaximumCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
@ -287,7 +283,6 @@ class MinimumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -307,17 +302,17 @@ class MinimumCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
@ -339,7 +334,6 @@ class MSELossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<MSELossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -359,18 +353,18 @@ class MSELossCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
cnnlStatus_t stat;
if (reduction == MSELossObj::None) {
stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc,
@ -396,7 +390,6 @@ class PowerCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -417,17 +410,17 @@ class PowerCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -450,7 +443,6 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -470,17 +462,17 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetFloorDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -503,7 +495,6 @@ class FloorModCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -523,17 +514,17 @@ class FloorModCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetFloorModWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -556,7 +547,6 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -576,17 +566,17 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(),
a_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, b_dim.size(),
b_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, c_dim.size(),
c_dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.size(), c_dim.data()));
size_t wsSize;
cnnlGetSquaredDifferenceWorkspaceSize(context->cnnlHandle(), aDesc,

View File

@ -7,7 +7,6 @@ class ErfCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ErfCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class ExpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ExpCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class FillCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<FillObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -17,9 +16,9 @@ class FillCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlFill(context->cnnlHandle(), value, cDesc, cData);

View File

@ -7,7 +7,6 @@ class FloorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class FloorCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlFloor(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class GatherCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<GatherObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,9 +19,9 @@ class GatherCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(
cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
@ -30,9 +29,9 @@ class GatherCnnl : public BangKernelWithoutConfig {
CNNL_DTYPE_INT32, bDim.size(),
bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
BangPtr wsData = context->getWorkspace(aDim.size() * 4);
context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);

View File

@ -7,7 +7,6 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<HardtanhObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,7 +19,8 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data()));
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlStatus_t stat = cnnlHardtanh(context->cnnlHandle(), aDesc, aData,
max, min, aDesc, cData);

View File

@ -7,7 +7,6 @@ class L2LossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<L2LossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,7 +17,8 @@ class L2LossCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data()));
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlStatus_t stat =
cnnlL2Loss(context->cnnlHandle(), aDesc, aData, cData);

View File

@ -8,7 +8,6 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -29,17 +28,17 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, inDims.size(),
inDims.data()));
checkCnnlError(cnnlSetTensorDescriptor(
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
inDims.size(), inDims.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc));
checkCnnlError(cnnlSetTensorDescriptor(
fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(),
fiterDims.data()));
fiterDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
fiterDims.size(), fiterDims.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, outDims.size(),
outDims.data()));
checkCnnlError(cnnlSetTensorDescriptor(
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outDims.size(), outDims.data()));
size_t wsSize;
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
&wsSize);

View File

@ -7,7 +7,6 @@ class LogCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LogObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -33,13 +32,13 @@ class LogCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class LRNCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LRNObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,13 +19,13 @@ class LRNCnnl : public BangKernelWithoutConfig {
auto size = op->getSize();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
size_t extra_size;
cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc,

View File

@ -8,7 +8,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto input_num = op->numInputs();
@ -38,25 +37,26 @@ class MatmulCnnl : public BangKernelWithoutConfig {
int32_t transB = op->getTransB();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(
cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
dimInputs0.size(), dimInputs0.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimInputs0.size(), dimInputs0.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(
cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
dimInputs1.size(), dimInputs1.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimInputs1.size(), dimInputs1.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
dimOutput.size(), dimOutput.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimOutput.size(), dimOutput.data()));
if (input_num > 2) {
checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc));
checkCnnlError(cnnlSetTensorDescriptor(
biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(),
dimBias.data()));
checkCnnlError(
cnnlSetTensorDescriptor(biasDesc, CNNL_LAYOUT_ARRAY,
cnnlDataTypeConvert(op->getDType()),
dimBias.size(), dimBias.data()));
}
cnnlMatMulDescriptor_t bmm_desc;

View File

@ -7,7 +7,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlNegTensor(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class PadCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PadObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -37,14 +36,14 @@ class PadCnnl : public BangKernelWithoutConfig {
float paddingValue = 0.0;
// input
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimIn.size(),
dimIn.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimIn.size(), dimIn.data()));
// output
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimOut.size(),
dimOut.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimOut.size(), dimOut.data()));
cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData,
paddings, &paddingValue, cDesc, cData);

View File

@ -8,7 +8,6 @@ class PoolingCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -20,8 +19,9 @@ class PoolingCnnl : public BangKernelWithoutConfig {
int inArray[4] = {n, c, h, w};
cnnlTensorDescriptor_t inDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, inArray));
checkCnnlError(cnnlSetTensorDescriptor(
inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inArray));
bool mode = op->getCeilMode();
// get maxpool descriptor
@ -37,8 +37,9 @@ class PoolingCnnl : public BangKernelWithoutConfig {
int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]};
cnnlTensorDescriptor_t outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, outArray));
checkCnnlError(cnnlSetTensorDescriptor(
outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
outArray));
size_t wsSize;
cnnlGetPoolingWorkspaceSize(context->cnnlHandle(), getPoolingMode(),
outVec[3], outVec[2], &wsSize);

View File

@ -7,7 +7,6 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlReciprocal(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -9,7 +9,6 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -26,20 +25,20 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t inDesc, outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, bDim.size(),
bDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
bDim.size(), bDim.data()));
// get reduce descriptor
cnnlReduceDescriptor_t reduceDesc;
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
checkCnnlError(cnnlSetReduceDescriptor_v2(
reduceDesc, axes.data(), axes.size(), getReduceOp(),
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES, 0.0));
cnnlDataTypeConvert(op->getDType()), CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0));
// get workspace
size_t workspaceSize = 0;

View File

@ -7,7 +7,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -42,11 +42,13 @@ class SliceCnnl : public BangKernelWithoutConfig {
// input
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array));
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim_size, aDim_array));
// output
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array));
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
cDim_size, cDim_array));
cnnlStatus_t stat =
cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array,
@ -59,6 +61,6 @@ class SliceCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl,
"Slice_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -7,7 +7,6 @@ class SplitCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SplitObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numOutputs();
int axis = op->getDim();
@ -16,15 +15,17 @@ class SplitCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t desc;
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data()));
desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlTensorDescriptor_t descArray[num];
for (int i = 0; i < num; ++i) {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(cnnlSetTensorDescriptor(
descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT,
op->getOutput(i)->getDims().size(),
op->getOutput(i)->getDims().data()));
checkCnnlError(
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
cnnlDataTypeConvert(op->getDType()),
op->getOutput(i)->getDims().size(),
op->getOutput(i)->getDims().data()));
}
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());

View File

@ -7,7 +7,6 @@ class SqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class SqrtCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlStatus_t stat =
cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class TransposeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class TransposeCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimin.size(),
dimin.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimin.size(), dimin.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dimout.size(),
dimout.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimout.size(), dimout.data()));
auto permute = op->getPermute();
cnnlTransposeDescriptor_t opDesc;
@ -53,7 +52,6 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DepthToSpaceObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -73,13 +71,13 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, reshape.size(),
reshape.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
reshape.size(), reshape.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
transpose.size(), transpose.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
transpose.size(), transpose.data()));
cnnlTransposeDescriptor_t opDesc;
checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));

View File

@ -9,7 +9,6 @@ class TrigonCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,13 +19,13 @@ class TrigonCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.size(), cDim.data()));
cnnlTrigonDescriptor_t opDesc;
checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));

View File

@ -7,7 +7,6 @@ class WhereCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,21 +34,21 @@ class WhereCnnl : public BangKernelWithoutConfig {
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, bDim.size(),
bDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
bDim.size(), bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_BOOL, cDim.size(),
cDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&dDesc));
checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dDim.size(),
dDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
dDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dDim.size(), dDim.data()));
size_t wsSize;
cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc,
&wsSize);