support more data type on mlu (#211)

* support more data type

* clang format

* fix little bug

* fix cncl datatype

* fix format

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: Zhang Bolun <Chamberlain0w0@gmail.com>
This commit is contained in:
Hardy 2024-01-24 13:33:33 +08:00 committed by GitHub
parent 51086d2b8d
commit 09b2ecf98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 418 additions and 408 deletions

View File

@ -2,6 +2,7 @@
#include "cnnl.h" #include "cnnl.h"
#include "cnrt.h" #include "cnrt.h"
#include "core/common.h" #include "core/common.h"
#include "core/data_type.h"
#define checkBangError(call) \ #define checkBangError(call) \
{ \ { \
@ -27,4 +28,35 @@ namespace infini {
using BangPtr = void *; using BangPtr = void *;
inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) {
if (dataType == DataType::Float32) {
return CNNL_DTYPE_FLOAT;
}
if (dataType == DataType::Float16) {
return CNNL_DTYPE_HALF;
}
if (dataType == DataType::Double) {
return CNNL_DTYPE_DOUBLE;
}
if (dataType == DataType::Int8) {
return CNNL_DTYPE_INT8;
}
if (dataType == DataType::Int32) {
return CNNL_DTYPE_INT32;
}
if (dataType == DataType::UInt8) {
return CNNL_DTYPE_UINT8;
}
if (dataType == DataType::BFloat16) {
return CNNL_DTYPE_BFLOAT16;
}
if (dataType == DataType::Int64) {
return CNNL_DTYPE_INT64;
}
if (dataType == DataType::Bool) {
return CNNL_DTYPE_BOOL;
}
return CNNL_DTYPE_INVALID;
}
} // namespace infini } // namespace infini

View File

@ -11,7 +11,6 @@ class UnaryCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -22,13 +21,13 @@ class UnaryCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlActivationDescriptor_t opDesc; cnnlActivationDescriptor_t opDesc;
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
checkCnnlError(cnnlSetActivationDescriptor_v2( checkCnnlError(cnnlSetActivationDescriptor_v2(
@ -51,7 +50,6 @@ class RoundCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -62,13 +60,13 @@ class RoundCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlRound(context->cnnlHandle(), aDesc, aData, cDesc, cData); cnnlRound(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)
@ -82,7 +80,6 @@ class PReluCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PReluObj>(_op); auto op = as<PReluObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -95,17 +92,17 @@ class PReluCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, bDim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
bDim.data())); bDim.size(), bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlPrelu(context->cnnlHandle(), aDesc, aData, cnnlStatus_t stat = cnnlPrelu(context->cnnlHandle(), aDesc, aData,
bDesc, bData, cDesc, cData); bDesc, bData, cDesc, cData);
@ -122,7 +119,6 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<SoftmaxObj>(_op); auto op = as<SoftmaxObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -185,13 +181,13 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, inDim.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
inDim.data())); inDim.size(), inDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, outDim.size(), cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outDim.data())); outDim.size(), outDim.data()));
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 0.0;
cnnlStatus_t stat = cnnlStatus_t stat =

View File

@ -10,7 +10,6 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ActivationBackwardObj>(_op); auto op = as<ActivationBackwardObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const yData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -25,21 +24,21 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
auto diffxDim = op->getOutput()->getDims(); auto diffxDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&yDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&yDesc));
checkCnnlError(cnnlSetTensorDescriptor(yDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, yDim.size(), yDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
yDim.data())); yDim.size(), yDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&diffYDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&diffYDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
diffYDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffyDim.size(), diffYDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
diffyDim.data())); diffyDim.size(), diffyDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&xDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&xDesc));
checkCnnlError(cnnlSetTensorDescriptor(xDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, xDim.size(), xDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
xDim.data())); xDim.size(), xDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&diffXDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&diffXDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
diffXDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, diffxDim.size(), diffXDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
diffxDim.data())); diffxDim.size(), diffxDim.data()));
// get op descriptor // get op descriptor
cnnlActivationDescriptor_t opDesc; cnnlActivationDescriptor_t opDesc;
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));

View File

@ -19,18 +19,17 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
BangPtr output_temp = BangPtr output_temp =
context->getWorkspace(op->getInputs(0)->getBytes() * world_size); context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
// void *output = op->getOutput()->getRawDataPtr<void *>(); // void *output = op->getOutput()->getRawDataPtr<void *>();
// IT_ASSERT(op->getDType() == DataType::Float32);
checkBangError(cnrtMalloc(&output_temp, checkBangError(cnrtMalloc(&output_temp,
op->getInputs(0)->getBytes() * world_size)); op->getInputs(0)->getBytes() * world_size));
size_t bytes = op->getInputs(0)->getBytes(); size_t bytes = op->getInputs(0)->getBytes();
size_t count = bytes / op->getDType().getSize(); size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm = cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator()) dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm(); .getCnclComm();
cnrtQueue_t queue = context->getBangQueue(); cnrtQueue_t queue = context->getBangQueue();
CNCL_CHECK( CNCL_CHECK(
cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue)); cnclAllGather(input, output_temp, count, cnclUint8, comm, queue));
checkBangError(cnrtQueueSync(queue)); checkBangError(cnrtQueueSync(queue));
for (int i = 0; i < world_size; ++i) { for (int i = 0; i < world_size; ++i) {
Tensor output = op->getOutput(i); Tensor output = op->getOutput(i);
@ -42,8 +41,8 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AllGather, AllGatherCNCL,
AllGatherCNCL, "AllGather_CNCL_BANG_Float32"); "AllGather_CNCL_BANG");
} // namespace infini } // namespace infini
#endif #endif

View File

@ -13,14 +13,14 @@ class AllReduceCNCL : public BangKernelWithoutConfig {
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>(); void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>(); void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32); size_t bytes = op->getInputs(0)->getBytes();
size_t count = op->getInputs(0)->size(); size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm = cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator()) dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm(); .getCnclComm();
cnrtQueue_t queue = context->getBangQueue(); cnrtQueue_t queue = context->getBangQueue();
// checkBangError(cnrtQueueSync(queue)); // checkBangError(cnrtQueueSync(queue));
CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(), CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(),
comm, queue)); comm, queue));
checkBangError(cnrtQueueSync(queue)); checkBangError(cnrtQueueSync(queue));
} }
@ -41,13 +41,13 @@ class AllReduceMaxCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclMax; } cnclReduceOp_t getRedOp() const override { return cnclMax; }
}; };
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, AllReduceSumCNCL,
AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32"); "AllReduce_Sum_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, AllReduceProdCNCL,
AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32"); "AllReduce_Prod_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, AllReduceMinCNCL,
AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32"); "AllReduce_Min_CNCL_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, AllReduceMaxCNCL,
AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32"); "AllReduce_Max_CNCL_BANG");
} // namespace infini } // namespace infini
#endif #endif

View File

@ -7,7 +7,6 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<BatchNormObj>(_op); auto op = as<BatchNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const input = (op->getInputs(0)->getRawDataPtr<void *>()); void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
@ -33,18 +32,18 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dims.size(), inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dims.data())); dims.size(), dims.data()));
checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dims.size(), intransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
dimsTrans)); dims.size(), dimsTrans));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, outDims.size(), outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
outDims.data())); outDims.size(), outDims.data()));
checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, outDims.size(), outtransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
dimsOutTrans)); outDims.size(), dimsOutTrans));
cnnlTransposeDescriptor_t opDesc; cnnlTransposeDescriptor_t opDesc;
checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));
checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute)); checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute));
@ -53,9 +52,9 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
&wsSize); &wsSize);
BangPtr wsData = context->getWorkspace(wsSize); BangPtr wsData = context->getWorkspace(wsSize);
BangPtr inputTrans = context->getWorkspace( BangPtr inputTrans = context->getWorkspace(
cnnlGetTensorElementNum(inDesc) * sizeof(float)); cnnlGetTensorElementNum(inDesc) * op->getDType().getSize());
BangPtr outputTrans = context->getWorkspace( BangPtr outputTrans = context->getWorkspace(
cnnlGetTensorElementNum(inDesc) * sizeof(float)); cnnlGetTensorElementNum(inDesc) * op->getDType().getSize());
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input, cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input,
intransDesc, inputTrans, wsData, wsSize); intransDesc, inputTrans, wsData, wsSize);
@ -67,7 +66,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t paraDesc; cnnlTensorDescriptor_t paraDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, paraDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data())); dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data()));
float alpha = 1.f, beta = 0.f; float alpha = 1.f, beta = 0.f;

View File

@ -13,22 +13,22 @@ class BroadcastCNCL : public BangKernelWithoutConfig {
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>(); void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>(); void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32); size_t bytes = op->getInputs(0)->getBytes();
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); size_t count = bytes / sizeof(uint8_t);
cnclComm_t comm = cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator()) dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm(); .getCnclComm();
cnrtQueue_t queue = context->getBangQueue(); cnrtQueue_t queue = context->getBangQueue();
// TODO: Using default stream 0 for now. // TODO: Using default stream 0 for now.
CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32, CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(),
op->getRoot(), comm, queue)); comm, queue));
checkBangError(cnrtQueueSync(queue)); checkBangError(cnrtQueueSync(queue));
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::Broadcast, BroadcastCNCL,
BroadcastCNCL, "Broadcast_CNCL_BANG_Float32"); "Broadcast_CNCL_BANG");
} // namespace infini } // namespace infini
#endif #endif

View File

@ -7,7 +7,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class CeilCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlCeil(context->cnnlHandle(), aDesc, aData, cDesc, cData); cnnlCeil(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ClipObj>(_op); auto op = as<ClipObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -19,9 +18,9 @@ class ClipCnnl : public BangKernelWithoutConfig {
auto aDim = op->getInputs(0)->getDims(); auto aDim = op->getInputs(0)->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlClip(context->cnnlHandle(), aDesc, aData, &min, &max, cData); cnnlClip(context->cnnlHandle(), aDesc, aData, &min, &max, cData);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)

View File

@ -7,7 +7,6 @@ class ConcatCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ConcatObj>(_op); auto op = as<ConcatObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numInputs(); int num = op->numInputs();
int axis = op->getDim(); int axis = op->getDim();
@ -15,15 +14,16 @@ class ConcatCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
cnnlTensorDescriptor_t desc; cnnlTensorDescriptor_t desc;
checkCnnlError(cnnlCreateTensorDescriptor(&desc)); checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlTensorDescriptor_t descArray[num]; cnnlTensorDescriptor_t descArray[num];
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(
descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
cnnlDataTypeConvert(op->getDType()),
op->getInputs(i)->getDims().size(), op->getInputs(i)->getDims().size(),
op->getInputs(i)->getDims().data())); op->getInputs(i)->getDims().data()));
} }

View File

@ -7,7 +7,6 @@ class ConvCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op); auto op = as<ConvObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc; cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor( checkCnnlError(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -55,20 +55,24 @@ class ConvCnnl : public BangKernelWithoutConfig {
// get inputs // get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aInDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aInDesc));
checkCnnlError(cnnlSetTensorDescriptor(aInDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, inputs0)); aInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs0));
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs0Array)); aDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
inputs0Array));
checkCnnlError(cnnlCreateTensorDescriptor(&bInDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bInDesc));
checkCnnlError(cnnlSetTensorDescriptor(bInDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, inputs1)); bInDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs1));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs1Array)); bDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
inputs1Array));
int permute[4] = {0, 2, 3, 1}; int permute[4] = {0, 2, 3, 1};
cnnlTransposeDescriptor_t opDesc; cnnlTransposeDescriptor_t opDesc;
@ -80,7 +84,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
&wsSize); &wsSize);
BangPtr wsData = context->getWorkspace(wsSize); BangPtr wsData = context->getWorkspace(wsSize);
BangPtr aDataOut = context->getWorkspace( BangPtr aDataOut = context->getWorkspace(
cnnlGetTensorElementNum(aInDesc) * sizeof(float)); cnnlGetTensorElementNum(aInDesc) * op->getDType().getSize());
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlTranspose_v2(context->cnnlHandle(), opDesc, aInDesc, aData, cnnlTranspose_v2(context->cnnlHandle(), opDesc, aInDesc, aData,
aDesc, aDataOut, wsData, wsSize); aDesc, aDataOut, wsData, wsSize);
@ -91,7 +95,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
&wsSize); &wsSize);
wsData = context->getWorkspace(wsSize); wsData = context->getWorkspace(wsSize);
BangPtr bDataOut = context->getWorkspace( BangPtr bDataOut = context->getWorkspace(
cnnlGetTensorElementNum(bInDesc) * sizeof(float)); cnnlGetTensorElementNum(bInDesc) * op->getDType().getSize());
stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, bInDesc, bData, stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc, bInDesc, bData,
bDesc, bDataOut, wsData, wsSize); bDesc, bDataOut, wsData, wsSize);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)
@ -100,11 +104,13 @@ class ConvCnnl : public BangKernelWithoutConfig {
// get outputs // get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cInDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cInDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
cInDesc, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, outputArray)); cInDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()), 4,
outputArray));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, output)); cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
output));
cnnlConvolutionForwardAlgo_t algo; cnnlConvolutionForwardAlgo_t algo;
cnnlGetConvolutionForwardAlgorithm(context->cnnlHandle(), convDesc, cnnlGetConvolutionForwardAlgorithm(context->cnnlHandle(), convDesc,
@ -116,7 +122,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
algo, &wsSize); algo, &wsSize);
wsData = context->getWorkspace(wsSize); wsData = context->getWorkspace(wsSize);
BangPtr cDataIn = context->getWorkspace( BangPtr cDataIn = context->getWorkspace(
cnnlGetTensorElementNum(cInDesc) * sizeof(float)); cnnlGetTensorElementNum(cInDesc) * op->getDType().getSize());
stat = cnnlConvolutionForward( stat = cnnlConvolutionForward(
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc, context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,

View File

@ -7,7 +7,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ConvBaseObj>(_op); auto op = as<ConvBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc; cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor( checkCnnlError(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -43,14 +43,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
// get inputs // get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data())); aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimInputs0.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data())); bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimInputs1.data()));
// get outputs // get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data())); cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
dimOutput.data()));
cnnlConvolutionBwdDataAlgo_t algo; cnnlConvolutionBwdDataAlgo_t algo;
cnnlGetConvolutionBackwardDataAlgorithm( cnnlGetConvolutionBackwardDataAlgorithm(

View File

@ -7,7 +7,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ConvBackwardFilterObj>(_op); auto op = as<ConvBackwardFilterObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -21,8 +20,9 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
cnnlConvolutionDescriptor_t convDesc; cnnlConvolutionDescriptor_t convDesc;
checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc)); checkCnnlError(cnnlCreateConvolutionDescriptor(&convDesc));
checkCnnlError(cnnlSetConvolutionDescriptor( checkCnnlError(
convDesc, 4, pad, stride, dilation, g, CNNL_DTYPE_FLOAT)); cnnlSetConvolutionDescriptor(convDesc, 4, pad, stride, dilation, g,
cnnlDataTypeConvert(op->getDType())));
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -63,15 +63,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
// get inputs // get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs0Array)); aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs0Array));
checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans)); checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(aDescTrans, CNNL_LAYOUT_NHWC, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, aDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
inputs0ArrayTrans)); 4, inputs0ArrayTrans));
size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] * size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] *
dimInputs0[3] * sizeof(float); dimInputs0[3] * op->getDType().getSize();
BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size); BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size);
cnnlStatus_t stat = cnnlStatus_t stat =
@ -82,15 +83,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs1Array)); bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inputs1Array));
checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans)); checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(bDescTrans, CNNL_LAYOUT_NHWC, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, bDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
inputs1ArrayTrans)); 4, inputs1ArrayTrans));
size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] * size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] *
dimInputs1[3] * sizeof(float); dimInputs1[3] * op->getDType().getSize();
BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size); BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size);
stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData, stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData,
@ -101,15 +103,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
// get outputs // get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, outputArray)); cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
outputArray));
checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans)); checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans));
checkCnnlError(cnnlSetTensorDescriptor(cDescTrans, CNNL_LAYOUT_NHWC, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, cDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
outputArrayTrans)); 4, outputArrayTrans));
size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] * size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] *
dimOutput[3] * sizeof(float); dimOutput[3] * op->getDType().getSize();
BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size); BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size);
cnnlConvolutionBwdFilterAlgo_t algo; cnnlConvolutionBwdFilterAlgo_t algo;

View File

@ -7,7 +7,6 @@ class DetCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<DetObj>(_op); auto op = as<DetObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -24,14 +23,14 @@ class DetCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims(); auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimin.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimin.data())); dimin.size(), dimin.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimout.size(), cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimout.data())); dimout.size(), dimout.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData); cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);

View File

@ -11,7 +11,6 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -31,24 +30,25 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
cnnlOpTensorDescriptor_t opDesc; cnnlOpTensorDescriptor_t opDesc;
checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc)); checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc));
checkCnnlError(cnnlSetOpTensorDescriptor( checkCnnlError(cnnlSetOpTensorDescriptor(
opDesc, getOpType(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN)); opDesc, getOpType(), cnnlDataTypeConvert(op->getDType()),
CNNL_NOT_PROPAGATE_NAN));
size_t wsSize; size_t wsSize;
cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -75,7 +75,6 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -95,17 +94,17 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetLogicOpWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetLogicOpWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -129,7 +128,6 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -182,7 +180,6 @@ class DivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -202,17 +199,17 @@ class DivCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -235,7 +232,6 @@ class MaximumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -255,17 +251,17 @@ class MaximumCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
@ -287,7 +283,6 @@ class MinimumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -307,17 +302,17 @@ class MinimumCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize); cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
@ -339,7 +334,6 @@ class MSELossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<MSELossObj>(_op); auto op = as<MSELossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -359,18 +353,18 @@ class MSELossCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
cnnlStatus_t stat; cnnlStatus_t stat;
if (reduction == MSELossObj::None) { if (reduction == MSELossObj::None) {
stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc, stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc,
@ -396,7 +390,6 @@ class PowerCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -417,17 +410,17 @@ class PowerCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -450,7 +443,6 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -470,17 +462,17 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetFloorDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetFloorDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -503,7 +495,6 @@ class FloorModCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -523,17 +514,17 @@ class FloorModCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetFloorModWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, cnnlGetFloorModWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
@ -556,7 +547,6 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -576,17 +566,17 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, a_dim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
a_dim.data())); a_dim.size(), a_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, b_dim.size(), bDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
b_dim.data())); b_dim.size(), b_dim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, c_dim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
c_dim.data())); c_dim.size(), c_dim.data()));
size_t wsSize; size_t wsSize;
cnnlGetSquaredDifferenceWorkspaceSize(context->cnnlHandle(), aDesc, cnnlGetSquaredDifferenceWorkspaceSize(context->cnnlHandle(), aDesc,

View File

@ -7,7 +7,6 @@ class ErfCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ErfCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class ExpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ExpCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class FillCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<FillObj>(_op); auto op = as<FillObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -17,9 +16,9 @@ class FillCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlFill(context->cnnlHandle(), value, cDesc, cData); cnnlFill(context->cnnlHandle(), value, cDesc, cData);

View File

@ -7,7 +7,6 @@ class FloorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class FloorCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlFloor(context->cnnlHandle(), aDesc, aData, cDesc, cData); cnnlFloor(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class GatherCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<GatherObj>(_op); auto op = as<GatherObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,9 +19,9 @@ class GatherCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError( checkCnnlError(
cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST)); cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
@ -30,9 +29,9 @@ class GatherCnnl : public BangKernelWithoutConfig {
CNNL_DTYPE_INT32, bDim.size(), CNNL_DTYPE_INT32, bDim.size(),
bDim.data())); bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
BangPtr wsData = context->getWorkspace(aDim.size() * 4); BangPtr wsData = context->getWorkspace(aDim.size() * 4);
context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4); context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);

View File

@ -7,7 +7,6 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<HardtanhObj>(_op); auto op = as<HardtanhObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,7 +19,8 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlStatus_t stat = cnnlHardtanh(context->cnnlHandle(), aDesc, aData, cnnlStatus_t stat = cnnlHardtanh(context->cnnlHandle(), aDesc, aData,
max, min, aDesc, cData); max, min, aDesc, cData);

View File

@ -7,7 +7,6 @@ class L2LossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<L2LossObj>(_op); auto op = as<L2LossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,7 +17,8 @@ class L2LossCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlL2Loss(context->cnnlHandle(), aDesc, aData, cData); cnnlL2Loss(context->cnnlHandle(), aDesc, aData, cData);

View File

@ -8,7 +8,6 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op); auto op = as<LayerNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -29,17 +28,17 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc; cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, inDims.size(), inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
inDims.data())); inDims.size(), inDims.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(), fiterDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
fiterDims.data())); fiterDims.size(), fiterDims.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, outDims.size(), outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outDims.data())); outDims.size(), outDims.data()));
size_t wsSize; size_t wsSize;
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc, cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
&wsSize); &wsSize);

View File

@ -7,7 +7,6 @@ class LogCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LogObj>(_op); auto op = as<LogObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -33,13 +32,13 @@ class LogCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class LRNCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LRNObj>(_op); auto op = as<LRNObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,13 +19,13 @@ class LRNCnnl : public BangKernelWithoutConfig {
auto size = op->getSize(); auto size = op->getSize();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
size_t extra_size; size_t extra_size;
cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc, cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc,

View File

@ -8,7 +8,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<MatmulObj>(_op); auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto input_num = op->numInputs(); auto input_num = op->numInputs();
@ -38,25 +37,26 @@ class MatmulCnnl : public BangKernelWithoutConfig {
int32_t transB = op->getTransB(); int32_t transB = op->getTransB();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError( checkCnnlError(cnnlSetTensorDescriptor(
cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimInputs0.size(), dimInputs0.data())); dimInputs0.size(), dimInputs0.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError( checkCnnlError(cnnlSetTensorDescriptor(
cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimInputs1.size(), dimInputs1.data())); dimInputs1.size(), dimInputs1.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError( checkCnnlError(cnnlSetTensorDescriptor(
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimOutput.size(), dimOutput.data())); dimOutput.size(), dimOutput.data()));
if (input_num > 2) { if (input_num > 2) {
checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(
biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(), cnnlSetTensorDescriptor(biasDesc, CNNL_LAYOUT_ARRAY,
dimBias.data())); cnnlDataTypeConvert(op->getDType()),
dimBias.size(), dimBias.data()));
} }
cnnlMatMulDescriptor_t bmm_desc; cnnlMatMulDescriptor_t bmm_desc;

View File

@ -7,7 +7,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlNegTensor(context->cnnlHandle(), aDesc, aData, cDesc, cData); cnnlNegTensor(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -7,7 +7,6 @@ class PadCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PadObj>(_op); auto op = as<PadObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -37,14 +36,14 @@ class PadCnnl : public BangKernelWithoutConfig {
float paddingValue = 0.0; float paddingValue = 0.0;
// input // input
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimIn.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimIn.data())); dimIn.size(), dimIn.data()));
// output // output
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimOut.size(), cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimOut.data())); dimOut.size(), dimOut.data()));
cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData, cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData,
paddings, &paddingValue, cDesc, cData); paddings, &paddingValue, cDesc, cData);

View File

@ -8,7 +8,6 @@ class PoolingCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op); auto op = as<PoolingObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>()); void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -20,8 +19,9 @@ class PoolingCnnl : public BangKernelWithoutConfig {
int inArray[4] = {n, c, h, w}; int inArray[4] = {n, c, h, w};
cnnlTensorDescriptor_t inDesc; cnnlTensorDescriptor_t inDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, inArray)); inDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
inArray));
bool mode = op->getCeilMode(); bool mode = op->getCeilMode();
// get maxpool descriptor // get maxpool descriptor
@ -37,8 +37,9 @@ class PoolingCnnl : public BangKernelWithoutConfig {
int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]}; int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]};
cnnlTensorDescriptor_t outDesc; cnnlTensorDescriptor_t outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, outArray)); outDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()), 4,
outArray));
size_t wsSize; size_t wsSize;
cnnlGetPoolingWorkspaceSize(context->cnnlHandle(), getPoolingMode(), cnnlGetPoolingWorkspaceSize(context->cnnlHandle(), getPoolingMode(),
outVec[3], outVec[2], &wsSize); outVec[3], outVec[2], &wsSize);

View File

@ -7,7 +7,6 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlReciprocal(context->cnnlHandle(), aDesc, aData, cDesc, cData); cnnlReciprocal(context->cnnlHandle(), aDesc, aData, cDesc, cData);

View File

@ -9,7 +9,6 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ReduceBaseObj>(_op); auto op = as<ReduceBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -26,20 +25,20 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t inDesc, outDesc; cnnlTensorDescriptor_t inDesc, outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, bDim.size(), outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
bDim.data())); bDim.size(), bDim.data()));
// get reduce descriptor // get reduce descriptor
cnnlReduceDescriptor_t reduceDesc; cnnlReduceDescriptor_t reduceDesc;
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc)); checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
checkCnnlError(cnnlSetReduceDescriptor_v2( checkCnnlError(cnnlSetReduceDescriptor_v2(
reduceDesc, axes.data(), axes.size(), getReduceOp(), reduceDesc, axes.data(), axes.size(), getReduceOp(),
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, cnnlDataTypeConvert(op->getDType()), CNNL_NOT_PROPAGATE_NAN,
CNNL_32BIT_INDICES, 0.0)); CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0));
// get workspace // get workspace
size_t workspaceSize = 0; size_t workspaceSize = 0;

View File

@ -7,7 +7,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -42,11 +42,13 @@ class SliceCnnl : public BangKernelWithoutConfig {
// input // input
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array)); aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim_size, aDim_array));
// output // output
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array)); cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
cDim_size, cDim_array));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array, cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array,
@ -59,6 +61,6 @@ class SliceCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl,
"Slice_cnnl_BANG_Float32"); "Slice_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,7 +7,6 @@ class SplitCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<SplitObj>(_op); auto op = as<SplitObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numOutputs(); int num = op->numOutputs();
int axis = op->getDim(); int axis = op->getDim();
@ -16,13 +15,15 @@ class SplitCnnl : public BangKernelWithoutConfig {
cnnlTensorDescriptor_t desc; cnnlTensorDescriptor_t desc;
checkCnnlError(cnnlCreateTensorDescriptor(&desc)); checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, dim.size(), dim.data())); desc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
dim.size(), dim.data()));
cnnlTensorDescriptor_t descArray[num]; cnnlTensorDescriptor_t descArray[num];
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i])); checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(
descArray[i], CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
cnnlDataTypeConvert(op->getDType()),
op->getOutput(i)->getDims().size(), op->getOutput(i)->getDims().size(),
op->getOutput(i)->getDims().data())); op->getOutput(i)->getDims().data()));
} }

View File

@ -7,7 +7,6 @@ class SqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class SqrtCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,

View File

@ -7,7 +7,6 @@ class TransposeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op); auto op = as<TransposeObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -18,13 +17,13 @@ class TransposeCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims(); auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimin.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimin.data())); dimin.size(), dimin.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dimout.size(), cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dimout.data())); dimout.size(), dimout.data()));
auto permute = op->getPermute(); auto permute = op->getPermute();
cnnlTransposeDescriptor_t opDesc; cnnlTransposeDescriptor_t opDesc;
@ -53,7 +52,6 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<DepthToSpaceObj>(_op); auto op = as<DepthToSpaceObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -73,12 +71,12 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
auto dimout = op->getOutput()->getDims(); auto dimout = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, reshape.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
reshape.data())); reshape.size(), reshape.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError( checkCnnlError(cnnlSetTensorDescriptor(
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
transpose.size(), transpose.data())); transpose.size(), transpose.data()));
cnnlTransposeDescriptor_t opDesc; cnnlTransposeDescriptor_t opDesc;

View File

@ -9,7 +9,6 @@ class TrigonCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -20,13 +19,13 @@ class TrigonCnnl : public BangKernelWithoutConfig {
auto cDim = op->getOutput()->getDims(); auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, cDim.size(), cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
cDim.data())); cDim.size(), cDim.data()));
cnnlTrigonDescriptor_t opDesc; cnnlTrigonDescriptor_t opDesc;
checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc)); checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));

View File

@ -7,7 +7,6 @@ class WhereCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op); auto op = as<WhereObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,21 +34,21 @@ class WhereCnnl : public BangKernelWithoutConfig {
} }
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, aDim.size(), aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
aDim.data())); aDim.size(), aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, bDim.size(), bDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
bDim.data())); bDim.size(), bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_BOOL, cDim.size(), CNNL_DTYPE_BOOL, cDim.size(),
cDim.data())); cDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&dDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&dDesc));
checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dDim.size(), dDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
dDim.data())); dDim.size(), dDim.data()));
size_t wsSize; size_t wsSize;
cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc, cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc,
&wsSize); &wsSize);