fix some MLU kernel registrations & the Gather op (#210)

* fix: bang build / kernel registration; update test_onnx

* delete assert float

* fix gather

* fix CMakeLists and Reshape

* fix cncl ops

* add hardsigmoid/hardswish

* fix

* add invalid datatype exception

* fix gather

* fix gather indices type

* fix gather/prelu/hardsigmoid on mlu

* fix format

* fix

---------

Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Co-authored-by: Zhang Bolun <Chamberlain0w0@gmail.com>
Author: zhangyunze, 2024-02-01 15:02:02 +08:00 (committed via GitHub)
Commit: 67b2bcb7d5 (parent: 4813204a36)
13 changed files with 283 additions and 31 deletions


@@ -245,7 +245,6 @@ if(USE_BANG)
     find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
     find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
     find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
-    find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
     if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -262,12 +261,13 @@ if(USE_BANG)
     # BangC Kernels
     ################################################################################
-    target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
     if (BUILD_DIST)
+        find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
+        target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
         message(STATUS "Add BUILD_DIST, use CNCL with BANG")
         add_compile_definitions(INFINI_USE_CNCL=1)
+    else()
+        target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
     endif()
 endif()
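
With libcncl now located and linked only when BUILD_DIST is ON, any source that touches CNCL has to stay behind the INFINI_USE_CNCL definition added above. A minimal self-contained sketch (not from the commit, illustrative only) of what that compile-time split means for code that consumes the flag:

    // Sketch: INFINI_USE_CNCL, defined by BUILD_DIST, is what keeps CNCL
    // symbols out of non-distributed builds that no longer link libcncl.
    #include <cstdio>

    int main() {
    #ifdef INFINI_USE_CNCL
        std::puts("distributed build: CNCL collectives compiled in");
    #else
        std::puts("single-device build: CNCL not linked");
    #endif
        return 0;
    }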


@@ -7,6 +7,7 @@ KUNLUN ?= OFF
 INTELCPU ?= off
 BACKTRACE ?= ON
 TEST ?= ON
+DIST ?= OFF
 NNET ?= OFF
 FORMAT_ORIGIN ?=
 # Docker build options
@@ -29,6 +30,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
 CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
 CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
 CMAKE_OPT += -DBUILD_TEST=$(TEST)
+CMAKE_OPT += -DBUILD_DIST=$(DIST)
 CMAKE_OPT += -DBUILD_NNET=$(NNET)
 
 ifeq ($(INTELCPU), ON)


@@ -3,6 +3,9 @@
 #include "cnrt.h"
 #include "core/common.h"
 #include "core/data_type.h"
+#ifdef INFINI_USE_CNCL
+#include "cncl.h"
+#endif
 
 #define checkBangError(call) \
     { \
@@ -56,7 +59,42 @@ inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) {
     if (dataType == DataType::Bool) {
         return CNNL_DTYPE_BOOL;
     }
-    return CNNL_DTYPE_INVALID;
+    IT_TODO_HALT_MSG("Data type " + dataType.toString() +
+                     " not supported in CNNL.");
 }
+
+#ifdef INFINI_USE_CNCL
+inline cnclDataType_t cnclDataTypeConvert(DataType dataType) {
+    if (dataType == DataType::Float32) {
+        return cnclFloat32;
+    }
+    if (dataType == DataType::Float16) {
+        return cnclHalf;
+    }
+    if (dataType == DataType::Int8) {
+        return cnclInt8;
+    }
+    if (dataType == DataType::Int16) {
+        return cnclInt16;
+    }
+    if (dataType == DataType::Int32) {
+        return cnclInt32;
+    }
+    if (dataType == DataType::UInt8) {
+        return cnclUint8;
+    }
+    if (dataType == DataType::UInt16) {
+        return cnclUint16;
+    }
+    if (dataType == DataType::UInt32) {
+        return cnclUint32;
+    }
+    if (dataType == DataType::BFloat16) {
+        return cnclBfloat16;
+    }
+    IT_TODO_HALT_MSG("Data type " + dataType.toString() +
+                     " not supported in CNCL.");
+}
+#endif
+
 } // namespace infini


@@ -463,13 +463,20 @@ class TestStringMethods(unittest.TestCase):
     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         split = make_node("Split", ["input"], ["output"], name="split", axis=0)
-        make_and_import_model(make_graph([split], "split", [input], []))
+        output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
+        make_and_import_model(make_graph([split], "split", [input], [output]))
 
     def test_split1(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
-        split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
-        make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
+        splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
+        output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
+        output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
+        split = make_node(
+            "Split", ["input", "split"], ["output1", "output2"], name="split", axis=1
+        )
+        make_and_import_model(
+            make_graph([split], "split", [input], [output1, output2], [splitAttr])
+        )
 
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])


@@ -2,12 +2,17 @@
 #include "bang/bang_runtime.h"
 #include "operators/softmax.h"
 #include "operators/unary.h"
+#include <iostream>
 
 namespace infini {
 class UnaryCnnl : public BangKernelWithoutConfig {
     virtual cnnlActivationMode_t getOpType() const = 0;
     virtual float getCoef() const = 0;
     virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
+    virtual float getSlicedDim() const { return 0.0; }
+    virtual float getGamma() const { return 0.0; }
+    virtual float getScale() const { return 0.0; }
 
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
@@ -30,9 +35,10 @@ class UnaryCnnl : public BangKernelWithoutConfig {
             cDim.size(), cDim.data()));
         cnnlActivationDescriptor_t opDesc;
         checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
-        checkCnnlError(cnnlSetActivationDescriptor_v2(
-            opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
-            CNNL_NOT_PROPAGATE_NAN, getCoef()));
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
+            opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
+            CNNL_NOT_PROPAGATE_NAN, getCoef(), getSlicedDim(), getGamma(),
+            getScale(), true));
 
         auto [alpha, beta] = getAlphBeta();
         cnnlStatus_t stat =
@@ -91,6 +97,10 @@ class PReluCnnl : public BangKernelWithoutConfig {
         auto bDim = op->getInputs(1)->getDims();
         auto cDim = op->getOutput()->getDims();
 
+        if (auto alignSize = aDim.size() - bDim.size(); alignSize) {
+            bDim.insert(bDim.begin(), alignSize, 1);
+        }
+
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
@@ -215,6 +225,22 @@ class SigmoidCnnl : public UnaryCnnl {
     float getCoef() const override { return 0.0; }
 };
 
+class HardSwishCnnl : public UnaryCnnl {
+    cnnlActivationMode_t getOpType() const override {
+        return CNNL_ACTIVATION_HARDSWISH;
+    }
+    float getCoef() const override { return 0.0; }
+};
+
+class HardSigmoidCnnl : public UnaryCnnl {
+    cnnlActivationMode_t getOpType() const override {
+        return CNNL_ACTIVATION_HARDSIGMOID;
+    }
+    float getCoef() const override { return 0.0; }
+    float getGamma() const override { return 1.f / 6.f; }
+    float getScale() const override { return 0.5f; }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
@@ -222,5 +248,9 @@ REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
                 "Softmax_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::HardSigmoid, HardSigmoidCnnl,
+                "HardSigmoid_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::HardSwish, HardSwishCnnl,
+                "HardSwish_cnnl_BANG");
 }; // namespace infini
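
For reference, the coefficients HardSigmoidCnnl passes (gamma = 1/6, scale = 0.5) correspond to hardsigmoid(x) = clamp(x/6 + 1/2, 0, 1), with HardSwish being x * hardsigmoid(x). A small standalone sketch of the expected math, assuming CNNL applies gamma and scale as those two coefficients (the exact interpretation is CNNL's, not confirmed by this diff):

    // Plain reference implementations (sketch, not the BANG kernels) of the
    // two newly registered activations, using the 1/6 and 0.5 values above.
    #include <algorithm>
    #include <cassert>
    #include <cmath>

    static float hardSigmoidRef(float x) {
        return std::max(0.0f, std::min(1.0f, x / 6.0f + 0.5f));
    }
    static float hardSwishRef(float x) { return x * hardSigmoidRef(x); }

    int main() {
        assert(std::fabs(hardSigmoidRef(0.0f) - 0.5f) < 1e-6f); // midpoint
        assert(std::fabs(hardSwishRef(3.0f) - 3.0f) < 1e-6f);   // identity above +3
        assert(hardSwishRef(-3.0f) == 0.0f);                    // zero below -3
        return 0;
    }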


@@ -22,14 +22,15 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
         checkBangError(cnrtMalloc(&output_temp,
                                   op->getInputs(0)->getBytes() * world_size));
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
-        CNCL_CHECK(
-            cnclAllGather(input, output_temp, count, cnclUint8, comm, queue));
+        CNCL_CHECK(cnclAllGather(input, output_temp, count,
+                                 cnclDataTypeConvert(op->getDType()), comm,
+                                 queue));
         checkBangError(cnrtQueueSync(queue));
         for (int i = 0; i < world_size; ++i) {
             Tensor output = op->getOutput(i);


@@ -14,14 +14,15 @@ class AllReduceCNCL : public BangKernelWithoutConfig {
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
         // checkBangError(cnrtQueueSync(queue));
-        CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(),
-                                 comm, queue));
+        CNCL_CHECK(cnclAllReduce(input, output, count,
+                                 cnclDataTypeConvert(op->getDType()),
+                                 getRedOp(), comm, queue));
         checkBangError(cnrtQueueSync(queue));
     }
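
The count change matters most here: passing getBytes() together with cnclUint8 asked CNCL to reduce raw bytes, which is not the same operation as reducing the underlying Float32 values, while a typed element count with the converted dtype is. A standalone sketch of the arithmetic, with illustrative values not taken from the commit:

    #include <cassert>
    #include <cstddef>

    int main() {
        // A Float32 tensor with 1024 elements:
        std::size_t bytes = 1024 * sizeof(float); // op->getInputs(0)->getBytes()
        std::size_t dtypeSize = sizeof(float);    // op->getDType().getSize()
        std::size_t count = bytes / dtypeSize;    // element count handed to CNCL
        assert(count == 1024);                    // previously 4096 "uint8 elements"
        return 0;
    }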


@@ -14,15 +14,16 @@ class BroadcastCNCL : public BangKernelWithoutConfig {
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
         // TODO: Using default stream 0 for now.
-        CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(),
-                                 comm, queue));
+        CNCL_CHECK(cnclBroadcast(input, output, count,
+                                 cnclDataTypeConvert(op->getDType()),
+                                 op->getRoot(), comm, queue));
         checkBangError(cnrtQueueSync(queue));
     }
 };


@@ -12,6 +12,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+        auto [aAlpha, bAlpha, beta] = getAlphBeta();
 
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@@ -51,12 +52,12 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
                                        CNNL_NOT_PROPAGATE_NAN));
 
         size_t wsSize;
-        cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
-                                     &wsSize);
+        cnnlGetOpTensorWorkspaceSize_v2(context->cnnlHandle(), opDesc, &aAlpha,
+                                        aDesc, aData, &bAlpha, bDesc, bData,
+                                        &beta, cDesc, cData, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
 
-        auto [aAlpha, bAlpha, beta] = getAlphBeta();
         cnnlStatus_t stat = cnnlOpTensor(context->cnnlHandle(), opDesc, &aAlpha,
                                          aDesc, aData, &bAlpha, bDesc, bData,
                                          wsData, wsSize, &beta, cDesc, cData);


@@ -23,23 +23,52 @@ class GatherCnnl : public BangKernelWithoutConfig {
             aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             aDim.size(), aDim.data()));
         checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
-        checkCnnlError(
-            cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_INT32, bDim.size(),
                                                bDim.data()));
+        BangPtr indices;
+        DataType indicesDataType = op->getInputs(1)->getDType();
+        if (indicesDataType == DataType::Int64) {
+            // cnnlGatherV2 does not support int64 indices
+            int indicesSize =
+                op->getInputs(1)->getBytes() / indicesDataType.getSize();
+            indices = context->getWorkspace(indicesSize * sizeof(int));
+            cnnlTensorDescriptor_t bDescInt64;
+            checkCnnlError(cnnlCreateTensorDescriptor(&bDescInt64));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                bDescInt64, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT64, bDim.size(),
+                bDim.data()));
+            checkCnnlError(cnnlCastDataType(context->cnnlHandle(), bDescInt64,
+                                            bData, CNNL_CAST_INT64_TO_INT32,
+                                            bDesc, indices));
+            cnrtQueueSync(context->getBangQueue());
+            checkCnnlError(cnnlDestroyTensorDescriptor(bDescInt64));
+        } else if (indicesDataType == DataType::Int32) {
+            indices = bData;
+        } else {
+            IT_TODO_HALT_MSG("Unsupported data type of indices: " +
+                             indicesDataType.toString());
+        }
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             cDim.size(), cDim.data()));
 
-        BangPtr wsData = context->getWorkspace(aDim.size() * 4);
-        context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);
+        BangPtr wsData = context->getWorkspace(aDim.size() * sizeof(int));
+        context->copyBlobFromCPU(wsData, aDim.data(),
+                                 aDim.size() * sizeof(int));
 
         auto axis = op->getAxis();
         cnnlStatus_t stat =
             cnnlGatherV2(context->cnnlHandle(), axis, aDesc, aData,
-                         (int *)wsData, bDesc, (int *)bData, cDesc, cData);
+                         reinterpret_cast<const int *>(wsData), bDesc,
+                         reinterpret_cast<const int *>(indices), cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
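
The new indices path exists because ONNX models typically carry Gather indices as int64, while cnnlGatherV2 only accepts int32; the kernel therefore narrows them on device with cnnlCastDataType into a workspace buffer. A host-side sketch of the same narrowing, illustrative only and not part of the kernel:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        // int64 indices as an ONNX model would provide them.
        std::vector<int64_t> indices64{0, 1, 1, 2};
        // Narrowed copy, the form the gather call accepts. Values outside the
        // int32 range would be truncated, so indices that large stay unsupported.
        std::vector<int32_t> indices32(indices64.begin(), indices64.end());
        assert(indices32.size() == indices64.size() && indices32[3] == 2);
        return 0;
    }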


@@ -14,8 +14,8 @@ class CopyBang : public BangKernelWithoutConfig {
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
-            aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
-            dim.size() * op->getDType().getSize(), dim.data()));
+            aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
+            dim.size(), dim.data()));
         cnnlStatus_t stat =
             cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
         if (stat != CNNL_STATUS_SUCCESS)
@@ -28,5 +28,7 @@ class CopyBang : public BangKernelWithoutConfig {
 REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Squeeze, CopyBang, "Squeeze_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Unsqueeze, CopyBang, "Unsqueeze_BANG");
 } // namespace infini
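
The descriptor fix is worth spelling out: the old call described the tensor as CNNL_DTYPE_INT8 and passed dim.size() * dtypeSize as the dimension count, so for a Float32 tensor it claimed a rank four times larger than the number of entries actually stored in dim.data(). A standalone sketch of that arithmetic, with illustrative values:

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<int> dim{2, 3, 4};   // a rank-3 Float32 tensor
        int dtypeSize = 4;               // op->getDType().getSize() for Float32
        int oldDimNb = static_cast<int>(dim.size()) * dtypeSize; // what was passed
        int newDimNb = static_cast<int>(dim.size());             // what is passed now
        assert(oldDimNb == 12 && newDimNb == 3); // 12 dims claimed, only 3 stored
        return 0;
    }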


@@ -29,7 +29,8 @@ class TrigonCnnl : public BangKernelWithoutConfig {
         cnnlTrigonDescriptor_t opDesc;
         checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));
-        checkCnnlError(cnnlSetTrigonDescriptor(opDesc, getOpType()));
+        checkCnnlError(
+            cnnlSetTrigonDescriptor_v2(opDesc, getOpType(), getPrefer()));
 
         cnnlStatus_t stat = cnnlTrigonForward(context->cnnlHandle(), opDesc,
                                               aDesc, aData, cDesc, cData);


@@ -0,0 +1,139 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/gather.h"
#include "test.h"
namespace infini {
/*
test1:
input = [
[1, 2],
[3, 4],
[5, 6],
]
indices = [
[0, 1],
[1, 2],
]
output = [
[
[1, 2],
[3, 4],
],
[
[3, 4],
[5, 6],
],
]
axis=0
*/
/*
test2
input = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
indices = [
[0, 2],
]
axis = 1,
output = [
[[0, 2]],
[[3, 5]],
[[6, 8]],
]
*/
/*
test3
input=[[[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11],
[12, 13],
[14, 15]]] //(2,4,2)
indices=[[0],[3],[1]] //(3,1)
axis=1
output=
*/
TEST(Gather, Mlu) {
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
        auto index = gCpu->addTensor({2, 2}, DataType::Int32);
        gCpu->dataMalloc();
        input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        index->copyin(vector<int>{0, 1, 1, 2});

        auto bangRuntime = make_ref<BangRuntimeObj>();
        Graph gMlu = make_ref<GraphObj>(bangRuntime);
        auto inputMlu = gMlu->cloneTensor(input);
        auto indexMlu = gMlu->cloneTensor(index);
        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 0);
        gMlu->dataMalloc();
        inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        indexMlu->copyin(vector<int>{0, 1, 1, 2});
        bangRuntime->run(gMlu);

        // copy output from MLU to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
    }
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 3}, DataType::Float32);
        auto index = gCpu->addTensor({1, 2}, DataType::Int32);
        gCpu->dataMalloc();
        input->setData(IncrementalGenerator());
        index->copyin(vector<int>{0, 2});

        auto bangRuntime = make_ref<BangRuntimeObj>();
        Graph gMlu = make_ref<GraphObj>(bangRuntime);
        auto inputMlu = gMlu->cloneTensor(input);
        auto indexMlu = gMlu->cloneTensor(index);
        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
        gMlu->dataMalloc();
        inputMlu->setData(IncrementalGenerator());
        indexMlu->copyin(vector<int>{0, 2});
        bangRuntime->run(gMlu);

        // copy output from MLU to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
    }
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
        auto index = gCpu->addTensor({3, 1}, DataType::Int32);
        gCpu->dataMalloc();
        input->setData(IncrementalGenerator());
        index->copyin(vector<int>{0, 3, 1});

        auto bangRuntime = make_ref<BangRuntimeObj>();
        Graph gMlu = make_ref<GraphObj>(bangRuntime);
        auto inputMlu = gMlu->cloneTensor(input);
        auto indexMlu = gMlu->cloneTensor(index);
        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
        gMlu->dataMalloc();
        inputMlu->setData(IncrementalGenerator());
        indexMlu->copyin(vector<int>{0, 3, 1});
        bangRuntime->run(gMlu);

        // copy output from MLU to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(
            vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
    }
}
} // namespace infini