forked from jiuyuan/InfiniTensor
Merge branch 'master' into rope_and_silu
This commit is contained in:
commit
b0876a13ce
|
@ -245,7 +245,6 @@ if(USE_BANG)
|
|||
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
|
||||
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
|
||||
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
|
||||
find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||
|
||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||
|
@ -262,12 +261,13 @@ if(USE_BANG)
|
|||
# BangC Kernels
|
||||
################################################################################
|
||||
|
||||
target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||
if (BUILD_DIST)
|
||||
find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
|
||||
target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||
message(STATUS "Add BUILD_DIST, use CNCL with BANG")
|
||||
|
||||
add_compile_definitions(INFINI_USE_CNCL=1)
|
||||
|
||||
else()
|
||||
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
2
Makefile
2
Makefile
|
@ -7,6 +7,7 @@ KUNLUN ?= OFF
|
|||
INTELCPU ?= off
|
||||
BACKTRACE ?= ON
|
||||
TEST ?= ON
|
||||
DIST ?= OFF
|
||||
NNET ?= OFF
|
||||
FORMAT_ORIGIN ?=
|
||||
# Docker build options
|
||||
|
@ -29,6 +30,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
|
|||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||
CMAKE_OPT += -DBUILD_DIST=$(DIST)
|
||||
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
||||
|
||||
ifeq ($(INTELCPU), ON)
|
||||
|
|
|
@ -3,6 +3,9 @@
|
|||
#include "cnrt.h"
|
||||
#include "core/common.h"
|
||||
#include "core/data_type.h"
|
||||
#ifdef INFINI_USE_CNCL
|
||||
#include "cncl.h"
|
||||
#endif
|
||||
|
||||
#define checkBangError(call) \
|
||||
{ \
|
||||
|
@ -56,7 +59,42 @@ inline cnnlDataType_t cnnlDataTypeConvert(DataType dataType) {
|
|||
if (dataType == DataType::Bool) {
|
||||
return CNNL_DTYPE_BOOL;
|
||||
}
|
||||
return CNNL_DTYPE_INVALID;
|
||||
IT_TODO_HALT_MSG("Data type " + dataType.toString() +
|
||||
" not supported in CNNL.");
|
||||
}
|
||||
|
||||
#ifdef INFINI_USE_CNCL
|
||||
inline cnclDataType_t cnclDataTypeConvert(DataType dataType) {
|
||||
if (dataType == DataType::Float32) {
|
||||
return cnclFloat32;
|
||||
}
|
||||
if (dataType == DataType::Float16) {
|
||||
return cnclHalf;
|
||||
}
|
||||
if (dataType == DataType::Int8) {
|
||||
return cnclInt8;
|
||||
}
|
||||
if (dataType == DataType::Int16) {
|
||||
return cnclInt16;
|
||||
}
|
||||
if (dataType == DataType::Int32) {
|
||||
return cnclInt32;
|
||||
}
|
||||
if (dataType == DataType::UInt8) {
|
||||
return cnclUint8;
|
||||
}
|
||||
if (dataType == DataType::UInt16) {
|
||||
return cnclUint16;
|
||||
}
|
||||
if (dataType == DataType::UInt32) {
|
||||
return cnclUint32;
|
||||
}
|
||||
if (dataType == DataType::BFloat16) {
|
||||
return cnclBfloat16;
|
||||
}
|
||||
IT_TODO_HALT_MSG("Data type " + dataType.toString() +
|
||||
" not supported in CNCL.");
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -463,13 +463,20 @@ class TestStringMethods(unittest.TestCase):
|
|||
def test_split(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||
make_and_import_model(make_graph([split], "split", [input], []))
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
make_and_import_model(make_graph([split], "split", [input], [output]))
|
||||
|
||||
def test_split1(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
|
||||
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
|
||||
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
|
||||
splitAttr = make_tensor("split", TensorProto.INT64, [2], [2, 1])
|
||||
output1 = make_tensor_value_info("output1", TensorProto.FLOAT, [1, 2, 2, 4])
|
||||
output2 = make_tensor_value_info("output2", TensorProto.FLOAT, [1, 1, 2, 4])
|
||||
split = make_node(
|
||||
"Split", ["input", "split"], ["output1", "output2"], name="split", axis=1
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph([split], "split", [input], [output1, output2], [splitAttr])
|
||||
)
|
||||
|
||||
def test_allBroadcast(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
|
|
|
@ -2,12 +2,17 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/unary.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace infini {
|
||||
class UnaryCnnl : public BangKernelWithoutConfig {
|
||||
virtual cnnlActivationMode_t getOpType() const = 0;
|
||||
virtual float getCoef() const = 0;
|
||||
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
|
||||
virtual float getSlicedDim() const { return 0.0; }
|
||||
virtual float getGamma() const { return 0.0; }
|
||||
virtual float getScale() const { return 0.0; }
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
|
@ -30,9 +35,10 @@ class UnaryCnnl : public BangKernelWithoutConfig {
|
|||
cDim.size(), cDim.data()));
|
||||
cnnlActivationDescriptor_t opDesc;
|
||||
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
|
||||
checkCnnlError(cnnlSetActivationDescriptor_v2(
|
||||
checkCnnlError(cnnlSetActivationDescriptor_v5(
|
||||
opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
|
||||
CNNL_NOT_PROPAGATE_NAN, getCoef()));
|
||||
CNNL_NOT_PROPAGATE_NAN, getCoef(), getSlicedDim(), getGamma(),
|
||||
getScale(), true));
|
||||
|
||||
auto [alpha, beta] = getAlphBeta();
|
||||
cnnlStatus_t stat =
|
||||
|
@ -91,6 +97,10 @@ class PReluCnnl : public BangKernelWithoutConfig {
|
|||
auto bDim = op->getInputs(1)->getDims();
|
||||
auto cDim = op->getOutput()->getDims();
|
||||
|
||||
if (auto alignSize = aDim.size() - bDim.size(); alignSize) {
|
||||
bDim.insert(bDim.begin(), alignSize, 1);
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
|
||||
|
@ -215,6 +225,22 @@ class SigmoidCnnl : public UnaryCnnl {
|
|||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
class HardSwishCnnl : public UnaryCnnl {
|
||||
cnnlActivationMode_t getOpType() const override {
|
||||
return CNNL_ACTIVATION_HARDSWISH;
|
||||
}
|
||||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
class HardSigmoidCnnl : public UnaryCnnl {
|
||||
cnnlActivationMode_t getOpType() const override {
|
||||
return CNNL_ACTIVATION_HARDSIGMOID;
|
||||
}
|
||||
float getCoef() const override { return 0.0; }
|
||||
float getGamma() const override { return 1.f / 6.f; }
|
||||
float getScale() const override { return 0.5f; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
|
||||
|
@ -222,5 +248,9 @@ REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
|
|||
REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
|
||||
"Softmax_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::HardSigmoid, HardSigmoidCnnl,
|
||||
"HardSigmoid_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::HardSwish, HardSwishCnnl,
|
||||
"HardSwish_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -22,14 +22,15 @@ class AllGatherCNCL : public BangKernelWithoutConfig {
|
|||
checkBangError(cnrtMalloc(&output_temp,
|
||||
op->getInputs(0)->getBytes() * world_size));
|
||||
size_t bytes = op->getInputs(0)->getBytes();
|
||||
size_t count = bytes / sizeof(uint8_t);
|
||||
size_t count = bytes / op->getDType().getSize();
|
||||
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
CNCL_CHECK(
|
||||
cnclAllGather(input, output_temp, count, cnclUint8, comm, queue));
|
||||
CNCL_CHECK(cnclAllGather(input, output_temp, count,
|
||||
cnclDataTypeConvert(op->getDType()), comm,
|
||||
queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
Tensor output = op->getOutput(i);
|
||||
|
|
|
@ -14,14 +14,15 @@ class AllReduceCNCL : public BangKernelWithoutConfig {
|
|||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
size_t bytes = op->getInputs(0)->getBytes();
|
||||
size_t count = bytes / sizeof(uint8_t);
|
||||
size_t count = bytes / op->getDType().getSize();
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
// checkBangError(cnrtQueueSync(queue));
|
||||
CNCL_CHECK(cnclAllReduce(input, output, count, cnclUint8, getRedOp(),
|
||||
comm, queue));
|
||||
CNCL_CHECK(cnclAllReduce(input, output, count,
|
||||
cnclDataTypeConvert(op->getDType()),
|
||||
getRedOp(), comm, queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
}
|
||||
|
||||
|
|
|
@ -14,15 +14,16 @@ class BroadcastCNCL : public BangKernelWithoutConfig {
|
|||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
size_t bytes = op->getInputs(0)->getBytes();
|
||||
size_t count = bytes / sizeof(uint8_t);
|
||||
size_t count = bytes / op->getDType().getSize();
|
||||
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
// TODO: Using default stream 0 for now.
|
||||
CNCL_CHECK(cnclBroadcast(input, output, count, cnclUint8, op->getRoot(),
|
||||
comm, queue));
|
||||
CNCL_CHECK(cnclBroadcast(input, output, count,
|
||||
cnclDataTypeConvert(op->getDType()),
|
||||
op->getRoot(), comm, queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -12,6 +12,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
|
|||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
auto [aAlpha, bAlpha, beta] = getAlphBeta();
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
|
@ -51,12 +52,12 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
|
|||
CNNL_NOT_PROPAGATE_NAN));
|
||||
|
||||
size_t wsSize;
|
||||
cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
|
||||
&wsSize);
|
||||
cnnlGetOpTensorWorkspaceSize_v2(context->cnnlHandle(), opDesc, &aAlpha,
|
||||
aDesc, aData, &bAlpha, bDesc, bData,
|
||||
&beta, cDesc, cData, &wsSize);
|
||||
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
auto [aAlpha, bAlpha, beta] = getAlphBeta();
|
||||
cnnlStatus_t stat = cnnlOpTensor(context->cnnlHandle(), opDesc, &aAlpha,
|
||||
aDesc, aData, &bAlpha, bDesc, bData,
|
||||
wsData, wsSize, &beta, cDesc, cData);
|
||||
|
|
|
@ -23,23 +23,52 @@ class GatherCnnl : public BangKernelWithoutConfig {
|
|||
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
aDim.size(), aDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
|
||||
|
||||
if (bDim.size() == 0) {
|
||||
bDim.push_back(1);
|
||||
}
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, bDim.size(),
|
||||
bDim.data()));
|
||||
|
||||
BangPtr indices;
|
||||
DataType indicesDataType = op->getInputs(1)->getDType();
|
||||
if (indicesDataType == DataType::Int64) {
|
||||
// cnnlGatherV2 does not support int64 indices
|
||||
int indicesSize =
|
||||
op->getInputs(1)->getBytes() / indicesDataType.getSize();
|
||||
indices = context->getWorkspace(indicesSize * sizeof(int));
|
||||
cnnlTensorDescriptor_t bDescInt64;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDescInt64));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
bDescInt64, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT64, bDim.size(),
|
||||
bDim.data()));
|
||||
checkCnnlError(cnnlCastDataType(context->cnnlHandle(), bDescInt64,
|
||||
bData, CNNL_CAST_INT64_TO_INT32,
|
||||
bDesc, indices));
|
||||
cnrtQueueSync(context->getBangQueue());
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDescInt64));
|
||||
} else if (indicesDataType == DataType::Int32) {
|
||||
indices = bData;
|
||||
} else {
|
||||
IT_TODO_HALT_MSG("Unsupported data type of indices: " +
|
||||
indicesDataType.toString());
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
cDim.size(), cDim.data()));
|
||||
|
||||
BangPtr wsData = context->getWorkspace(aDim.size() * 4);
|
||||
context->copyBlobFromCPU(wsData, aDim.data(), aDim.size() * 4);
|
||||
BangPtr wsData = context->getWorkspace(aDim.size() * sizeof(int));
|
||||
context->copyBlobFromCPU(wsData, aDim.data(),
|
||||
aDim.size() * sizeof(int));
|
||||
|
||||
auto axis = op->getAxis();
|
||||
cnnlStatus_t stat =
|
||||
cnnlGatherV2(context->cnnlHandle(), axis, aDesc, aData,
|
||||
(int *)wsData, bDesc, (int *)bData, cDesc, cData);
|
||||
reinterpret_cast<const int *>(wsData), bDesc,
|
||||
reinterpret_cast<const int *>(indices), cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@ class CopyBang : public BangKernelWithoutConfig {
|
|||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
|
||||
dim.size() * op->getDType().getSize(), dim.data()));
|
||||
aDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
dim.size(), dim.data()));
|
||||
cnnlStatus_t stat =
|
||||
cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
|
@ -28,5 +28,7 @@ class CopyBang : public BangKernelWithoutConfig {
|
|||
REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Squeeze, CopyBang, "Squeeze_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Unsqueeze, CopyBang, "Unsqueeze_BANG");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -29,7 +29,8 @@ class TrigonCnnl : public BangKernelWithoutConfig {
|
|||
|
||||
cnnlTrigonDescriptor_t opDesc;
|
||||
checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));
|
||||
checkCnnlError(cnnlSetTrigonDescriptor(opDesc, getOpType()));
|
||||
checkCnnlError(
|
||||
cnnlSetTrigonDescriptor_v2(opDesc, getOpType(), getPrefer()));
|
||||
|
||||
cnnlStatus_t stat = cnnlTrigonForward(context->cnnlHandle(), opDesc,
|
||||
aDesc, aData, cDesc, cData);
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
#include "operators/reshape.h"
|
||||
#include "core/kernel.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
|
||||
namespace infini {
|
||||
class NaiveIdentity : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto size = _op->getInputs()[0]->getBytes();
|
||||
void *inptr = _op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *outptr = _op->getOutput()->getRawDataPtr<void *>();
|
||||
|
||||
std::memcpy(outptr, inptr, size);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Reshape, NaiveIdentity,
|
||||
"ReshapeNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Identity, NaiveIdentity,
|
||||
"IdentityNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Unsqueeze, NaiveIdentity,
|
||||
"UnsqueezeNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Squeeze, NaiveIdentity,
|
||||
"SqueezeNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Flatten, NaiveIdentity,
|
||||
"FlattenNaive_CPU");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,139 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/gather.h"
|
||||
|
||||
#include "test.h"
|
||||
namespace infini {
|
||||
/*
|
||||
test1:
|
||||
input = [
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
[5, 6],
|
||||
]
|
||||
indices = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
]
|
||||
output = [
|
||||
[
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
],
|
||||
[
|
||||
[3, 4],
|
||||
[5, 6],
|
||||
],
|
||||
]
|
||||
axis=0
|
||||
*/
|
||||
|
||||
/*
|
||||
test2
|
||||
input = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
indices = [
|
||||
[0, 2],
|
||||
]
|
||||
axis = 1,
|
||||
output = [
|
||||
[[0, 2]],
|
||||
[[3, 5]],
|
||||
[[6, 8]],
|
||||
]
|
||||
*/
|
||||
/*
|
||||
test3
|
||||
input=[[[ 0, 1],
|
||||
[ 2, 3],
|
||||
[ 4, 5],
|
||||
[ 6, 7]],
|
||||
|
||||
[[ 8, 9],
|
||||
[10, 11],
|
||||
[12, 13],
|
||||
[14, 15]]] //(2,4,2)
|
||||
indices=[[0],[3],[1]] //(3,1)
|
||||
axis=1
|
||||
output=
|
||||
|
||||
*/
|
||||
|
||||
TEST(Gather, Mlu) {
|
||||
{
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto input = gCpu->addTensor({3, 2}, DataType::Float32);
|
||||
auto index = gCpu->addTensor({2, 2}, DataType::Int32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
||||
index->copyin(vector<int>{0, 1, 1, 2});
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
Graph gMlu = make_ref<GraphObj>(bangRuntime);
|
||||
|
||||
auto inputMlu = gMlu->cloneTensor(input);
|
||||
auto indexMlu = gMlu->cloneTensor(index);
|
||||
auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 0);
|
||||
gMlu->dataMalloc();
|
||||
inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
||||
indexMlu->copyin(vector<int>{0, 1, 1, 2});
|
||||
bangRuntime->run(gMlu);
|
||||
|
||||
// copy output from MLU to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
|
||||
}
|
||||
{
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto input = gCpu->addTensor({3, 3}, DataType::Float32);
|
||||
auto index = gCpu->addTensor({1, 2}, DataType::Int32);
|
||||
gCpu->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
index->copyin(vector<int>{0, 2});
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
Graph gMlu = make_ref<GraphObj>(bangRuntime);
|
||||
|
||||
auto inputMlu = gMlu->cloneTensor(input);
|
||||
auto indexMlu = gMlu->cloneTensor(index);
|
||||
auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
|
||||
gMlu->dataMalloc();
|
||||
inputMlu->setData(IncrementalGenerator());
|
||||
indexMlu->copyin(vector<int>{0, 2});
|
||||
bangRuntime->run(gMlu);
|
||||
|
||||
// copy output from MLU to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
|
||||
}
|
||||
{
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
|
||||
auto index = gCpu->addTensor({3, 1}, DataType::Int32);
|
||||
gCpu->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
index->copyin(vector<int>{0, 3, 1});
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
Graph gMlu = make_ref<GraphObj>(bangRuntime);
|
||||
|
||||
auto inputMlu = gMlu->cloneTensor(input);
|
||||
auto indexMlu = gMlu->cloneTensor(index);
|
||||
auto op = gMlu->addOp<GatherObj>(inputMlu, indexMlu, nullptr, 1);
|
||||
gMlu->dataMalloc();
|
||||
inputMlu->setData(IncrementalGenerator());
|
||||
indexMlu->copyin(vector<int>{0, 3, 1});
|
||||
bangRuntime->run(gMlu);
|
||||
|
||||
// copy output from MLU to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
EXPECT_TRUE(oCpu->equalData(
|
||||
vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,21 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/reshape.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(Identity, NativeCpu) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto t1 = g->addTensor({2, 2, 3, 1}, DataType::Float32);
|
||||
auto op = g->addOp<IdentityObj>(t1, nullptr);
|
||||
g->dataMalloc();
|
||||
t1->setData(IncrementalGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
EXPECT_TRUE(op->getOutput()->equalData(
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
|
||||
}
|
||||
} // namespace infini
|
Loading…
Reference in New Issue