add reduce_mean and gather on bang (#167)

* add code

* fix reduce_mean

* add softmax on BANG

* fix gather

* fix boradcast on ele kernel when dim size is zero

* add where kernel and fix softmax kernel

* fix convbpdata bug

* fix format

---------

Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
Hardy 2023-11-10 18:02:44 +08:00 committed by GitHub
parent 50862df765
commit f22fa2766e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 349 additions and 14 deletions

View File

@ -1,5 +1,6 @@
#include "bang/bang_kernel_without_config.h" #include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h" #include "bang/bang_runtime.h"
#include "operators/softmax.h"
#include "operators/unary.h" #include "operators/unary.h"
namespace infini { namespace infini {
@ -113,6 +114,72 @@ class PReluCnnl : public BangKernelWithoutConfig {
} }
}; };
class SoftmaxCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SoftmaxObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto aDim = op->getInputs(0)->getDims();
cnnlSoftmaxMode_t mode;
size_t axis = op->getAxis();
std::vector<int> inDim = {1, 1, 1};
std::vector<int> outDim = inDim;
if (axis == 0) {
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
inDim[0] = aDim[0];
inDim[1] = aDim[1];
for (size_t i = 2; i < aDim.size(); ++i) {
inDim[2] *= aDim[i];
}
outDim = inDim;
} else if (axis == aDim.size() - 1) {
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
inDim[0] = aDim[0];
for (size_t i = 1; i < axis; ++i) {
inDim[1] *= aDim[i];
}
inDim[2] = aDim[axis];
outDim = inDim;
} else {
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
for (size_t i = 0; i < axis; ++i) {
inDim[0] *= aDim[i];
}
inDim[1] = aDim[axis];
for (size_t i = axis + 1; i < aDim.size(); ++i) {
inDim[2] *= aDim[i];
}
outDim = inDim;
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, inDim.size(),
inDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, outDim.size(),
outDim.data()));
float alpha = 1.0;
float beta = 0.0;
cnnlStatus_t stat =
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
mode, CNNL_COMPUTATION_HIGH_PRECISION, &alpha,
aDesc, aData, &beta, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
class ReluCnnl : public UnaryCnnl { class ReluCnnl : public UnaryCnnl {
cnnlActivationMode_t getOpType() const override { cnnlActivationMode_t getOpType() const override {
return CNNL_ACTIVATION_RELU; return CNNL_ACTIVATION_RELU;
@ -135,5 +202,7 @@ REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
"Sigmoid_cnnl_BANG_Float32"); "Sigmoid_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl, REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
"Round_cnnl_BANG_Float32"); "Round_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl,
"Softmax_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -39,24 +39,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
if (dimOutput.size() != 4) if (dimOutput.size() != 4)
IT_TODO_HALT(); IT_TODO_HALT();
int inputs0[4] = {dimInputs0[0], dimInputs0[1], dimInputs0[2],
dimInputs0[3]};
int inputs1[4] = {dimInputs1[0], dimInputs1[1], dimInputs1[2],
dimInputs1[3]};
int output[4] = {dimOutput[0], dimOutput[1], dimOutput[2],
dimOutput[3]};
// get inputs // get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, inputs0)); aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, inputs1)); bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data()));
// get outputs // get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, 4, output)); cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data()));
cnnlConvolutionBwdDataAlgo_t algo; cnnlConvolutionBwdDataAlgo_t algo;
cnnlGetConvolutionBackwardDataAlgorithm( cnnlGetConvolutionBackwardDataAlgorithm(
@ -69,7 +62,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
BangPtr wsData = context->getWorkspace(wsSize); BangPtr wsData = context->getWorkspace(wsSize);
cnnlStatus_t stat = cnnlConvolutionBackwardData( cnnlStatus_t stat = cnnlConvolutionBackwardData(
context->cnnlHandle(), NULL, aDesc, aData, bDesc, bData, convDesc, context->cnnlHandle(), NULL, bDesc, bData, aDesc, aData, convDesc,
algo, wsData, wsSize, NULL, cDesc, cData); algo, wsData, wsSize, NULL, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)
return; return;

View File

@ -21,6 +21,13 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -77,6 +84,13 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -123,6 +137,13 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -168,6 +189,13 @@ class DivCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -213,6 +241,13 @@ class MaximumCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -257,6 +292,13 @@ class MinimumCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -301,6 +343,13 @@ class MSELossCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -351,6 +400,14 @@ class PowerCnnl : public BangKernelWithoutConfig {
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, a_dim.size(), CNNL_DTYPE_FLOAT, a_dim.size(),
@ -395,6 +452,13 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -440,6 +504,13 @@ class FloorModCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@ -485,6 +556,13 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
if (a_dim.size() == 0) {
a_dim.push_back(1);
}
if (b_dim.size() == 0) {
b_dim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,

View File

@ -0,0 +1,53 @@
#include "operators/gather.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
namespace infini {
class GatherCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<GatherObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
auto aDim = op->getInputs(0)->getDims();
auto bDim = op->getInputs(1)->getDims();
auto cDim = op->getOutput()->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, bDim.size(),
bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, cDim.size(),
cDim.data()));
BangPtr wsData = context->getWorkspace(aDim.size() * 4);
context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);
auto axis = op->getAxis();
cnnlStatus_t stat =
cnnlGatherV2(context->cnnlHandle(), axis, aDesc, aData,
(int *)wsData, bDesc, (int *)bData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl,
"Gather_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -0,0 +1,69 @@
#include "operators/reduce_mean.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
namespace infini {
class ReduceMeanCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceMeanObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto aDim = op->getInputs(0)->getDims();
auto axes_set = op->getAxes();
std::vector<int> axes;
axes.assign(axes_set.begin(), axes_set.end());
auto bDim = aDim;
for (auto it : axes) {
bDim[it] = 1;
}
cnnlTensorDescriptor_t inDesc, outDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, bDim.size(),
bDim.data()));
// get reduce descriptor
cnnlReduceDescriptor_t reduceDesc;
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
checkCnnlError(cnnlSetReduceDescriptor_v2(
reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG,
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES, 0.0));
// get workspace
size_t workspaceSize = 0;
checkCnnlError(cnnlGetReduceOpWorkspaceSize(context->cnnlHandle(),
inDesc, outDesc, reduceDesc,
&workspaceSize));
int indicesSize = axes.size() * sizeof(int);
BangPtr wsData = context->getWorkspace(workspaceSize + indicesSize);
BangPtr indicesData = (char *)wsData + workspaceSize;
context->copyBlobFromCPU(indicesData, axes.data(), indicesSize);
// reduce
float alpha = 1.f, beta = 0.f;
checkCnnlError(cnnlReduce(
context->cnnlHandle(), reduceDesc, wsData, workspaceSize, &alpha,
inDesc, aData, indicesSize, indicesData, &beta, outDesc, cData));
// Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(outDesc));
checkCnnlError(cnnlDestroyReduceDescriptor(reduceDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
}; // namespace infini

73
src/kernels/bang/where.cc Normal file
View File

@ -0,0 +1,73 @@
#include "operators/where.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
namespace infini {
class WhereCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const cData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const dData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, bDesc, cDesc, dDesc;
auto aDim = op->getInputs(0)->getDims();
auto bDim = op->getInputs(1)->getDims();
auto cDim = op->getInputs(2)->getDims();
auto dDim = op->getOutput()->getDims();
if (aDim.size() == 0) {
aDim.push_back(1);
}
if (bDim.size() == 0) {
bDim.push_back(1);
}
if (cDim.size() == 0) {
cDim.push_back(1);
}
if (dDim.size() == 0) {
dDim.push_back(1);
}
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, bDim.size(),
bDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_BOOL, cDim.size(),
cDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&dDesc));
checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dDim.size(),
dDim.data()));
size_t wsSize;
cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc,
&wsSize);
BangPtr wsData = context->getWorkspace(wsSize);
cnnlStatus_t stat =
cnnlSelectV2(context->cnnlHandle(), cDesc, cData, aDesc, aData,
bDesc, bData, wsData, wsSize, dDesc, dData);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(dDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl,
"Where_cnnl_BANG_Float32");
}; // namespace infini