From 67b2bcb7d5064cd9e790f0698bb2e3563b0de280 Mon Sep 17 00:00:00 2001
From: zhangyunze <93699316+bitzyz@users.noreply.github.com>
Date: Thu, 1 Feb 2024 15:02:02 +0800
Subject: [PATCH] fix mlu some kernel registration & gather op (#210)

* fix: fix bang build/kernel registration | test_onnx
* delete assert float
* fix gather
* fix CMakeLists and Reshape
* fix cncl ops
* add hardsigmoid/hardswish
* fix
* add invalid datatype exception
* fix gather
* fix gather indices type
* fix gather/prelu/hardsigmoid on mlu
* fix format
* fix

---------

Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com>
Co-authored-by: Haojie Wang
Co-authored-by: Zhang Bolun
---
 CMakeLists.txt                        |   8 +-
 Makefile                              |   2 +
 include/bang/bang_common.h            |  40 +++++++-
 pyinfinitensor/tests/test_onnx.py     |  15 ++-
 src/kernels/bang/activation.cc        |  34 ++++++-
 src/kernels/bang/all_gather.cc        |   7 +-
 src/kernels/bang/all_reduce.cc        |   7 +-
 src/kernels/bang/broadcast.cc         |   7 +-
 src/kernels/bang/element_wise.cc      |   7 +-
 src/kernels/bang/gather.cc            |  39 +++++++-
 src/kernels/bang/reshape.cc           |   6 +-
 src/kernels/bang/trigon.cc            |   3 +-
 test/kernels/bang/test_bang_gather.cc | 139 ++++++++++++++++++++++++++
 13 files changed, 283 insertions(+), 31 deletions(-)
 create mode 100644 test/kernels/bang/test_bang_gather.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 70508c79..19d11183 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -245,7 +245,6 @@ if(USE_BANG)
     find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
     find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
     find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
-    find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
 
     if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -262,12 +261,13 @@ if(USE_BANG)
     # BangC Kernels
     ################################################################################
 
-    target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
     if (BUILD_DIST)
+        find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
+        target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
        message(STATUS "Add BUILD_DIST, use CNCL with BANG")
-
        add_compile_definitions(INFINI_USE_CNCL=1)
-
+    else()
+        target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
     endif()
 endif()
 
diff --git a/Makefile b/Makefile
index 302f47b8..ff2ad0a9 100644
--- a/Makefile
+++ b/Makefile
@@ -7,6 +7,7 @@ KUNLUN ?= OFF
 INTELCPU ?= off
 BACKTRACE ?= ON
 TEST ?= ON
+DIST ?= OFF
 NNET ?= OFF
 FORMAT_ORIGIN ?=
 # Docker build options
@@ -29,6 +30,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
 CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
 CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
 CMAKE_OPT += -DBUILD_TEST=$(TEST)
+CMAKE_OPT += -DBUILD_DIST=$(DIST)
 CMAKE_OPT += -DBUILD_NNET=$(NNET)
 
 ifeq ($(INTELCPU), ON)
diff --git a/include/bang/bang_common.h b/include/bang/bang_common.h
index 6c8047ef..019c283f 100644
--- a/include/bang/bang_common.h
+++ b/include/bang/bang_common.h
@@ -3,6 +3,9 @@
 #include "cnrt.h"
 #include "core/common.h"
 #include "core/data_type.h"
+#ifdef INFINI_USE_CNCL
+#include "cncl.h"
+#endif
 
 #define checkBangError(call)                                                   \
     {                                                                          \
@@ -56,7 +59,42 @@ inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) {
     if (dataType == DataType::Bool) {
         return CNNL_DTYPE_BOOL;
     }
-    return CNNL_DTYPE_INVALID;
+    IT_TODO_HALT_MSG("Data type " + dataType.toString() +
+                     " not supported in CNNL.");
 }
+
+#ifdef INFINI_USE_CNCL
+inline cnclDataType_t cnclDataTypeConvert(DataType dataType) {
+    if (dataType == DataType::Float32) {
+        return cnclFloat32;
+    }
+    if (dataType == DataType::Float16) {
+        return cnclHalf;
+    }
+    if (dataType == DataType::Int8) {
+        return cnclInt8;
+    }
+    if (dataType == DataType::Int16) {
+        return cnclInt16;
+    }
+    if (dataType == DataType::Int32) {
+        return cnclInt32;
+    }
+    if (dataType == DataType::UInt8) {
+        return cnclUint8;
+    }
+    if (dataType == DataType::UInt16) {
+        return cnclUint16;
+    }
+    if (dataType == DataType::UInt32) {
+        return cnclUint32;
+    }
+    if (dataType == DataType::BFloat16) {
+        return cnclBfloat16;
+    }
+    IT_TODO_HALT_MSG("Data type " + dataType.toString() +
+                     " not supported in CNCL.");
+}
+#endif
+
 } // namespace infini
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 4d9c7574..8c0dba5f 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -463,13 +463,20 @@ class TestStringMethods(unittest.TestCase):
     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         split = make_node("Split", ["input"], ["output"], name="split", axis=0)
-        make_and_import_model(make_graph([split], "split", [input], []))
+        output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
+        make_and_import_model(make_graph([split], "split", [input], [output]))
 
     def test_split1(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
-        split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
-        make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
+        splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
+        output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
+        output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
+        split = make_node(
+            "Split", ["input", "split"], ["output1", "output2"], name="split", axis=1
+        )
+        make_and_import_model(
+            make_graph([split], "split", [input], [output1, output2], [splitAttr])
+        )
 
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc
index 41de7534..4105b168 100644
--- a/src/kernels/bang/activation.cc
+++ b/src/kernels/bang/activation.cc
@@ -2,12 +2,17 @@
 #include "bang/bang_runtime.h"
 #include "operators/softmax.h"
 #include "operators/unary.h"
+#include
 
 namespace infini {
 class UnaryCnnl : public BangKernelWithoutConfig {
     virtual cnnlActivationMode_t getOpType() const = 0;
     virtual float getCoef() const = 0;
     virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
+    virtual float getSlicedDim() const { return 0.0; }
+    virtual float getGamma() const { return 0.0; }
+    virtual float getScale() const { return 0.0; }
+
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
@@ -30,9 +35,10 @@ class UnaryCnnl : public BangKernelWithoutConfig {
                                                cDim.size(), cDim.data()));
         cnnlActivationDescriptor_t opDesc;
         checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
-        checkCnnlError(cnnlSetActivationDescriptor_v2(
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
             opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
-            CNNL_NOT_PROPAGATE_NAN, getCoef()));
+            CNNL_NOT_PROPAGATE_NAN, getCoef(), getSlicedDim(), getGamma(),
+            getScale(), true));
 
         auto [alpha, beta] = getAlphBeta();
         cnnlStatus_t stat =
@@ -91,6 +97,10 @@ class PReluCnnl : public BangKernelWithoutConfig {
         auto bDim = op->getInputs(1)->getDims();
         auto cDim = op->getOutput()->getDims();
 
+        if (auto alignSize = aDim.size() - bDim.size(); alignSize) {
+            bDim.insert(bDim.begin(), alignSize, 1);
+        }
+
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
@@ -215,6 +225,22 @@ class SigmoidCnnl : public UnaryCnnl {
     float getCoef() const override { return 0.0; }
 };
 
+class HardSwishCnnl : public UnaryCnnl {
+    cnnlActivationMode_t getOpType() const override {
+        return CNNL_ACTIVATION_HARDSWISH;
+    }
+    float getCoef() const override { return 0.0; }
+};
+
+class HardSigmoidCnnl : public UnaryCnnl {
+    cnnlActivationMode_t getOpType() const override {
+        return CNNL_ACTIVATION_HARDSIGMOID;
+    }
+    float getCoef() const override { return 0.0; }
+    float getGamma() const override { return 1.f / 6.f; }
+    float getScale() const override { return 0.5f; }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
@@ -222,5 +248,9 @@ REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
                 "Softmax_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::HardSigmoid, HardSigmoidCnnl,
+                "HardSigmoid_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::HardSwish, HardSwishCnnl,
+                "HardSwish_cnnl_BANG");
 
 }; // namespace infini
diff --git a/src/kernels/bang/all_gather.cc b/src/kernels/bang/all_gather.cc
index f35d71d4..f02648f1 100644
--- a/src/kernels/bang/all_gather.cc
+++ b/src/kernels/bang/all_gather.cc
@@ -22,14 +22,15 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
             checkBangError(cnrtMalloc(&output_temp,
                                       op->getInputs(0)->getBytes() * world_size));
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
-        CNCL_CHECK(
-            cnclAllGather(input, output_temp, count, cnclUint8, comm, queue));
+        CNCL_CHECK(cnclAllGather(input, output_temp, count,
+                                 cnclDataTypeConvert(op->getDType()), comm,
+                                 queue));
         checkBangError(cnrtQueueSync(queue));
         for (int i = 0; i < world_size; ++i) {
             Tensor output = op->getOutput(i);
diff --git a/src/kernels/bang/all_reduce.cc b/src/kernels/bang/all_reduce.cc
index c9e42c65..ca180be2 100644
--- a/src/kernels/bang/all_reduce.cc
+++ b/src/kernels/bang/all_reduce.cc
@@ -14,14 +14,15 @@ class AllReduceCNCL : public BangKernelWithoutConfig {
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
         // checkBangError(cnrtQueueSync(queue));
-        CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(),
-                                 comm, queue));
+        CNCL_CHECK(cnclAllReduce(input, output, count,
+                                 cnclDataTypeConvert(op->getDType()),
+                                 getRedOp(), comm, queue));
         checkBangError(cnrtQueueSync(queue));
     }
diff --git a/src/kernels/bang/broadcast.cc b/src/kernels/bang/broadcast.cc
index e0bc5879..dda92200 100644
--- a/src/kernels/bang/broadcast.cc
+++ b/src/kernels/bang/broadcast.cc
@@ -14,15 +14,16 @@ class BroadcastCNCL : public BangKernelWithoutConfig {
         void *input = op->getInputs(0)->getRawDataPtr<void *>();
         void *output = op->getOutput()->getRawDataPtr<void *>();
         size_t bytes = op->getInputs(0)->getBytes();
-        size_t count = bytes / sizeof(uint8_t);
+        size_t count = bytes / op->getDType().getSize();
 
         cnclComm_t comm =
             dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
                 .getCnclComm();
         cnrtQueue_t queue = context->getBangQueue();
         // TODO: Using default stream 0 for now.
-        CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(),
-                                 comm, queue));
+        CNCL_CHECK(cnclBroadcast(input, output, count,
+                                 cnclDataTypeConvert(op->getDType()),
+                                 op->getRoot(), comm, queue));
         checkBangError(cnrtQueueSync(queue));
     }
 };
diff --git a/src/kernels/bang/element_wise.cc b/src/kernels/bang/element_wise.cc
index da18a6d4..90edda99 100644
--- a/src/kernels/bang/element_wise.cc
+++ b/src/kernels/bang/element_wise.cc
@@ -12,6 +12,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+        auto [aAlpha, bAlpha, beta] = getAlphBeta();
 
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
@@ -51,12 +52,12 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
                                                CNNL_NOT_PROPAGATE_NAN));
 
         size_t wsSize;
-        cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
-                                     &wsSize);
+        cnnlGetOpTensorWorkspaceSize_v2(context->cnnlHandle(), opDesc, &aAlpha,
+                                        aDesc, aData, &bAlpha, bDesc, bData,
+                                        &beta, cDesc, cData, &wsSize);
 
         BangPtr wsData = context->getWorkspace(wsSize);
 
-        auto [aAlpha, bAlpha, beta] = getAlphBeta();
         cnnlStatus_t stat =
             cnnlOpTensor(context->cnnlHandle(), opDesc, &aAlpha, aDesc, aData,
                          &bAlpha, bDesc, bData, wsData, wsSize, &beta, cDesc, cData);
diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc
index 63c0a872..4071512f 100644
--- a/src/kernels/bang/gather.cc
+++ b/src/kernels/bang/gather.cc
@@ -23,23 +23,52 @@ class GatherCnnl : public BangKernelWithoutConfig {
             aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             aDim.size(), aDim.data()));
         checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
-        checkCnnlError(
-            cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
+
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_INT32, bDim.size(),
                                                bDim.data()));
+
+        BangPtr indices;
+        DataType indicesDataType = op->getInputs(1)->getDType();
+        if (indicesDataType == DataType::Int64) {
+            // cnnlGatherV2 does not support int64 indices
+            int indicesSize =
+                op->getInputs(1)->getBytes() / indicesDataType.getSize();
+            indices = context->getWorkspace(indicesSize * sizeof(int));
+            cnnlTensorDescriptor_t bDescInt64;
+            checkCnnlError(cnnlCreateTensorDescriptor(&bDescInt64));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                bDescInt64, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT64, bDim.size(),
+                bDim.data()));
+            checkCnnlError(cnnlCastDataType(context->cnnlHandle(), bDescInt64,
+                                            bData, CNNL_CAST_INT64_TO_INT32,
+                                            bDesc, indices));
+            cnrtQueueSync(context->getBangQueue());
+            checkCnnlError(cnnlDestroyTensorDescriptor(bDescInt64));
+        } else if (indicesDataType == DataType::Int32) {
+            indices = bData;
+        } else {
+            IT_TODO_HALT_MSG("Unsupported data type of indices: " +
+                             indicesDataType.toString());
+        }
+
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             cDim.size(), cDim.data()));
 
-        BangPtr wsData = context->getWorkspace(aDim.size() * 4);
-        context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);
+        BangPtr wsData = context->getWorkspace(aDim.size() * sizeof(int));
+        context->copyBlobFromCPU(wsData, aDim.data(),
+                                 aDim.size() * sizeof(int));
 
         auto axis = op->getAxis();
         cnnlStatus_t stat =
             cnnlGatherV2(context->cnnlHandle(), axis, aDesc, aData,
-                         (int *)wsData, bDesc, (int *)bData, cDesc, cData);
+                         reinterpret_cast<const int *>(wsData), bDesc,
+                         reinterpret_cast<const int *>(indices), cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
diff --git a/src/kernels/bang/reshape.cc b/src/kernels/bang/reshape.cc
index cd876bf1..22b217a5 100644
--- a/src/kernels/bang/reshape.cc
+++ b/src/kernels/bang/reshape.cc
@@ -14,8 +14,8 @@ class CopyBang : public BangKernelWithoutConfig {
 
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
-            aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
-            dim.size() * op->getDType().getSize(), dim.data()));
+            aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
+            dim.size(), dim.data()));
         cnnlStatus_t stat =
             cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
         if (stat != CNNL_STATUS_SUCCESS)
@@ -28,5 +28,7 @@ class CopyBang : public BangKernelWithoutConfig {
 REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Squeeze, CopyBang, "Squeeze_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Unsqueeze, CopyBang, "Unsqueeze_BANG");
 
 } // namespace infini
diff --git a/src/kernels/bang/trigon.cc b/src/kernels/bang/trigon.cc
index d6ce8273..02a58b8d 100644
--- a/src/kernels/bang/trigon.cc
+++ b/src/kernels/bang/trigon.cc
@@ -29,7 +29,8 @@ class TrigonCnnl : public BangKernelWithoutConfig {
 
         cnnlTrigonDescriptor_t opDesc;
         checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));
-        checkCnnlError(cnnlSetTrigonDescriptor(opDesc, getOpType()));
+        checkCnnlError(
+            cnnlSetTrigonDescriptor_v2(opDesc, getOpType(), getPrefer()));
 
         cnnlStatus_t stat = cnnlTrigonForward(context->cnnlHandle(), opDesc,
                                               aDesc, aData, cDesc, cData);
diff --git a/test/kernels/bang/test_bang_gather.cc b/test/kernels/bang/test_bang_gather.cc
new file mode 100644
index 00000000..09e88397
--- /dev/null
+++ b/test/kernels/bang/test_bang_gather.cc
@@ -0,0 +1,139 @@
+#include "bang/bang_runtime.h"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/gather.h"
+
+#include "test.h"
+
+namespace infini {
+/*
+test1:
+input = [
+    [1, 2],
+    [3, 4],
+    [5, 6],
+  ]
+indices = [
+    [0, 1],
+    [1, 2],
+  ]
+output = [
+    [
+      [1, 2],
+      [3, 4],
+    ],
+    [
+      [3, 4],
+      [5, 6],
+    ],
+  ]
+axis=0
+*/
+
+/*
+test2
+input = [
+    [0, 1, 2],
+    [3, 4, 5],
+    [6, 7, 8],
+  ]
+indices = [
+    [0, 2],
+  ]
+axis = 1,
+output = [
+    [[0, 2]],
+    [[3, 5]],
+    [[6, 8]],
+  ]
+*/
+
+/*
+test3
+input=[[[ 0,  1],
+        [ 2,  3],
+        [ 4,  5],
+        [ 6,  7]],
+
+       [[ 8,  9],
+        [10, 11],
+        [12, 13],
+        [14, 15]]] //(2,4,2)
+indices=[[0],[3],[1]] //(3,1)
+axis=1
+output=[[[[ 0,  1]],
+        [[ 6,  7]],
+        [[ 2,  3]]],
+
+       [[[ 8,  9]],
+        [[14, 15]],
+        [[10, 11]]]] //(2,3,1,2)
+*/
+
+TEST(Gather, Mlu) {
+    {
+        Runtime runtime = NativeCpuRuntimeObj::getInstance();
+        Graph gCpu = make_ref<GraphObj>(runtime);
+        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
+        auto index = gCpu->addTensor({2, 2}, DataType::Int32);
+        gCpu->dataMalloc();
+        input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
+        index->copyin(vector<int>{0, 1, 1, 2});
+        auto bangRuntime = make_ref<BangRuntimeObj>();
+        Graph gMlu = make_ref<GraphObj>(bangRuntime);
+
+        auto inputMlu = gMlu->cloneTensor(input);
+        auto indexMlu = gMlu->cloneTensor(index);
+        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 0);
+        gMlu->dataMalloc();
+        inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6});
+        indexMlu->copyin(vector<int>{0, 1, 1, 2});
+        bangRuntime->run(gMlu);
+
+        // copy output from MLU to CPU
+        auto oCpu = gCpu->cloneTensor(op->getOutput());
+        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
+    }
+    {
+        Runtime runtime = NativeCpuRuntimeObj::getInstance();
+        Graph gCpu = make_ref<GraphObj>(runtime);
+        auto input = gCpu->addTensor({3, 3}, DataType::Float32);
+        auto index = gCpu->addTensor({1, 2}, DataType::Int32);
+        gCpu->dataMalloc();
+        input->setData(IncrementalGenerator());
+        index->copyin(vector<int>{0, 2});
+        auto bangRuntime = make_ref<BangRuntimeObj>();
+        Graph gMlu = make_ref<GraphObj>(bangRuntime);
+
+        auto inputMlu = gMlu->cloneTensor(input);
+        auto indexMlu = gMlu->cloneTensor(index);
+        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
+        gMlu->dataMalloc();
+        inputMlu->setData(IncrementalGenerator());
+        indexMlu->copyin(vector<int>{0, 2});
+        bangRuntime->run(gMlu);
+
+        // copy output from MLU to CPU
+        auto oCpu = gCpu->cloneTensor(op->getOutput());
+        EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
+    }
+    {
+        Runtime runtime = NativeCpuRuntimeObj::getInstance();
+        Graph gCpu = make_ref<GraphObj>(runtime);
+        auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
+        auto index = gCpu->addTensor({3, 1}, DataType::Int32);
+        gCpu->dataMalloc();
+        input->setData(IncrementalGenerator());
+        index->copyin(vector<int>{0, 3, 1});
+        auto bangRuntime = make_ref<BangRuntimeObj>();
+        Graph gMlu = make_ref<GraphObj>(bangRuntime);
+
+        auto inputMlu = gMlu->cloneTensor(input);
+        auto indexMlu = gMlu->cloneTensor(index);
+        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
+        gMlu->dataMalloc();
+        inputMlu->setData(IncrementalGenerator());
+        indexMlu->copyin(vector<int>{0, 3, 1});
+        bangRuntime->run(gMlu);
+
+        // copy output from MLU to CPU
+        auto oCpu = gCpu->cloneTensor(op->getOutput());
+        EXPECT_TRUE(oCpu->equalData(
+            vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
+    }
+}
+
+} // namespace infini
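
The cases above all use Int32 indices, so the new Int64 branch in src/kernels/bang/gather.cc (the cnnlCastDataType conversion to CNNL_DTYPE_INT32 before cnnlGatherV2) is not exercised by this test. A minimal sketch of an extra block for TEST(Gather, Mlu) that would cover it is given below; it mirrors the first case, and it assumes GatherObj accepts a DataType::Int64 index tensor (this block is illustrative and not part of the patch):

    // Hypothetical extra case: same gather as the first block, but with int64
    // indices so GatherCnnl takes the int64 -> int32 cast path on the MLU.
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
        auto index = gCpu->addTensor({2, 2}, DataType::Int64);
        gCpu->dataMalloc();
        input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        index->copyin(vector<int64_t>{0, 1, 1, 2});
        auto bangRuntime = make_ref<BangRuntimeObj>();
        Graph gMlu = make_ref<GraphObj>(bangRuntime);

        auto inputMlu = gMlu->cloneTensor(input);
        auto indexMlu = gMlu->cloneTensor(index);
        auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 0);
        gMlu->dataMalloc();
        inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        indexMlu->copyin(vector<int64_t>{0, 1, 1, 2});
        bangRuntime->run(gMlu);

        // rows {0,1} and {1,2} are gathered, so the expected output is the
        // same as in the int32 case
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
    }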