forked from jiuyuan/InfiniTensor
add reduce_mean and gather on bang (#167)
* add code * fix reduce_mean * add softmax on BANG * fix gather * fix boradcast on ele kernel when dim size is zero * add where kernel and fix softmax kernel * fix convbpdata bug * fix format --------- Co-authored-by: wanghailu <wanghailu@qiyuanlab.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
50862df765
commit
f22fa2766e
|
@ -1,5 +1,6 @@
|
|||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -113,6 +114,72 @@ class PReluCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class SoftmaxCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
|
||||
cnnlSoftmaxMode_t mode;
|
||||
size_t axis = op->getAxis();
|
||||
std::vector<int> inDim = {1, 1, 1};
|
||||
std::vector<int> outDim = inDim;
|
||||
|
||||
if (axis == 0) {
|
||||
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||
inDim[0] = aDim[0];
|
||||
inDim[1] = aDim[1];
|
||||
for (size_t i = 2; i < aDim.size(); ++i) {
|
||||
inDim[2] *= aDim[i];
|
||||
}
|
||||
outDim = inDim;
|
||||
} else if (axis == aDim.size() - 1) {
|
||||
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
||||
inDim[0] = aDim[0];
|
||||
for (size_t i = 1; i < axis; ++i) {
|
||||
inDim[1] *= aDim[i];
|
||||
}
|
||||
inDim[2] = aDim[axis];
|
||||
outDim = inDim;
|
||||
} else {
|
||||
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
|
||||
for (size_t i = 0; i < axis; ++i) {
|
||||
inDim[0] *= aDim[i];
|
||||
}
|
||||
inDim[1] = aDim[axis];
|
||||
for (size_t i = axis + 1; i < aDim.size(); ++i) {
|
||||
inDim[2] *= aDim[i];
|
||||
}
|
||||
outDim = inDim;
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, inDim.size(),
|
||||
inDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, outDim.size(),
|
||||
outDim.data()));
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
cnnlStatus_t stat =
|
||||
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
|
||||
mode, CNNL_COMPUTATION_HIGH_PRECISION, &alpha,
|
||||
aDesc, aData, &beta, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
}
|
||||
};
|
||||
|
||||
class ReluCnnl : public UnaryCnnl {
|
||||
cnnlActivationMode_t getOpType() const override {
|
||||
return CNNL_ACTIVATION_RELU;
|
||||
|
@ -135,5 +202,7 @@ REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
|
|||
"Sigmoid_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
|
||||
"Round_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl,
|
||||
"Softmax_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -39,24 +39,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
|||
if (dimOutput.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int inputs0[4] = {dimInputs0[0], dimInputs0[1], dimInputs0[2],
|
||||
dimInputs0[3]};
|
||||
int inputs1[4] = {dimInputs1[0], dimInputs1[1], dimInputs1[2],
|
||||
dimInputs1[3]};
|
||||
int output[4] = {dimOutput[0], dimOutput[1], dimOutput[2],
|
||||
dimOutput[3]};
|
||||
|
||||
// get inputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, inputs0));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, inputs1));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data()));
|
||||
// get outputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, output));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data()));
|
||||
|
||||
cnnlConvolutionBwdDataAlgo_t algo;
|
||||
cnnlGetConvolutionBackwardDataAlgorithm(
|
||||
|
@ -69,7 +62,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
|||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
cnnlStatus_t stat = cnnlConvolutionBackwardData(
|
||||
context->cnnlHandle(), NULL, aDesc, aData, bDesc, bData, convDesc,
|
||||
context->cnnlHandle(), NULL, bDesc, bData, aDesc, aData, convDesc,
|
||||
algo, wsData, wsSize, NULL, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
|
|
@ -21,6 +21,13 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -77,6 +84,13 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -123,6 +137,13 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -168,6 +189,13 @@ class DivCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -213,6 +241,13 @@ class MaximumCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -257,6 +292,13 @@ class MinimumCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -301,6 +343,13 @@ class MSELossCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -351,6 +400,14 @@ class PowerCnnl : public BangKernelWithoutConfig {
|
|||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, a_dim.size(),
|
||||
|
@ -395,6 +452,13 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -440,6 +504,13 @@ class FloorModCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
@ -485,6 +556,13 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
if (a_dim.size() == 0) {
|
||||
a_dim.push_back(1);
|
||||
}
|
||||
|
||||
if (b_dim.size() == 0) {
|
||||
b_dim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
#include "operators/gather.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class GatherCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GatherObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
auto cDim = op->getOutput()->getDims();
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, aDim.size(),
|
||||
aDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, bDim.size(),
|
||||
bDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, cDim.size(),
|
||||
cDim.data()));
|
||||
|
||||
BangPtr wsData = context->getWorkspace(aDim.size() * 4);
|
||||
context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);
|
||||
|
||||
auto axis = op->getAxis();
|
||||
cnnlStatus_t stat =
|
||||
cnnlGatherV2(context->cnnlHandle(), axis, aDesc, aData,
|
||||
(int *)wsData, bDesc, (int *)bData, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl,
|
||||
"Gather_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,69 @@
|
|||
#include "operators/reduce_mean.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ReduceMeanCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceMeanObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto axes_set = op->getAxes();
|
||||
std::vector<int> axes;
|
||||
axes.assign(axes_set.begin(), axes_set.end());
|
||||
auto bDim = aDim;
|
||||
for (auto it : axes) {
|
||||
bDim[it] = 1;
|
||||
}
|
||||
|
||||
cnnlTensorDescriptor_t inDesc, outDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, aDim.size(),
|
||||
aDim.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, bDim.size(),
|
||||
bDim.data()));
|
||||
|
||||
// get reduce descriptor
|
||||
cnnlReduceDescriptor_t reduceDesc;
|
||||
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
|
||||
checkCnnlError(cnnlSetReduceDescriptor_v2(
|
||||
reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG,
|
||||
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
|
||||
CNNL_32BIT_INDICES, 0.0));
|
||||
|
||||
// get workspace
|
||||
size_t workspaceSize = 0;
|
||||
checkCnnlError(cnnlGetReduceOpWorkspaceSize(context->cnnlHandle(),
|
||||
inDesc, outDesc, reduceDesc,
|
||||
&workspaceSize));
|
||||
int indicesSize = axes.size() * sizeof(int);
|
||||
BangPtr wsData = context->getWorkspace(workspaceSize + indicesSize);
|
||||
|
||||
BangPtr indicesData = (char *)wsData + workspaceSize;
|
||||
context->copyBlobFromCPU(indicesData, axes.data(), indicesSize);
|
||||
|
||||
// reduce
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
checkCnnlError(cnnlReduce(
|
||||
context->cnnlHandle(), reduceDesc, wsData, workspaceSize, &alpha,
|
||||
inDesc, aData, indicesSize, indicesData, &beta, outDesc, cData));
|
||||
|
||||
// Destories in CUDA does not require sync. But cuDNN does not state
|
||||
// whether sync is required before destories.
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(outDesc));
|
||||
checkCnnlError(cnnlDestroyReduceDescriptor(reduceDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
|
||||
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,73 @@
|
|||
#include "operators/where.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class WhereCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<WhereObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const dData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, bDesc, cDesc, dDesc;
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
auto cDim = op->getInputs(2)->getDims();
|
||||
auto dDim = op->getOutput()->getDims();
|
||||
|
||||
if (aDim.size() == 0) {
|
||||
aDim.push_back(1);
|
||||
}
|
||||
if (bDim.size() == 0) {
|
||||
bDim.push_back(1);
|
||||
}
|
||||
if (cDim.size() == 0) {
|
||||
cDim.push_back(1);
|
||||
}
|
||||
if (dDim.size() == 0) {
|
||||
dDim.push_back(1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, aDim.size(),
|
||||
aDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, bDim.size(),
|
||||
bDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_BOOL, cDim.size(),
|
||||
cDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&dDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(dDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, dDim.size(),
|
||||
dDim.data()));
|
||||
size_t wsSize;
|
||||
cnnlGetSelectV2WorkspaceSize(context->cnnlHandle(), cDesc, aDesc, bDesc,
|
||||
&wsSize);
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
cnnlStatus_t stat =
|
||||
cnnlSelectV2(context->cnnlHandle(), cDesc, cData, aDesc, aData,
|
||||
bDesc, bData, wsData, wsSize, dDesc, dData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(dDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl,
|
||||
"Where_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
Loading…
Reference in New Issue