forked from jiuyuan/InfiniTensor
ADD: Gather operator and cuda kernel. (#41)
fix a memory leak. add tests. ADD gather cuda kernel. ADD gather operator Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
3c6e208f42
commit
fe14c91f54
|
@ -69,13 +69,13 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||||
// TODO: unify these copy APIs
|
// TODO: unify these copy APIs
|
||||||
virtual void copyBlobFromCPU(void *dst, const void *src,
|
virtual void copyBlobFromCPU(void *dst, const void *src,
|
||||||
size_t bytes) const = 0;
|
size_t bytes) const = 0;
|
||||||
|
virtual void copyBlobToCPU(void *dst, const void *src,
|
||||||
|
size_t bytes) const = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void printProfilingData(double totTime,
|
void printProfilingData(double totTime,
|
||||||
const std::map<OpType, double> &opTime,
|
const std::map<OpType, double> &opTime,
|
||||||
const std::map<OpType, int> &opCnt) const;
|
const std::map<OpType, int> &opCnt) const;
|
||||||
virtual void copyBlobToCPU(void *dst, const void *src,
|
|
||||||
size_t bytes) const = 0;
|
|
||||||
virtual void copyBlobInsideRuntime(void *dst, const void *src,
|
virtual void copyBlobInsideRuntime(void *dst, const void *src,
|
||||||
size_t bytes) const = 0;
|
size_t bytes) const = 0;
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int *indexValue;
|
||||||
|
int axis;
|
||||||
|
int inNDim;
|
||||||
|
int outNDim;
|
||||||
|
int idxNDim;
|
||||||
|
int outDim[4];
|
||||||
|
int idxDim[4];
|
||||||
|
int idxStride[4];
|
||||||
|
int inStride[4];
|
||||||
|
} GatherMetaData;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num);
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class GatherObj : public OperatorObj {
|
||||||
|
int axis;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
|
||||||
|
int axis);
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 2; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
int getAxis() const { return axis; }
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool CheckIndexValid() const;
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,48 @@
|
||||||
|
#include "operators/gather.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/gather.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void initGatherMetaData(GatherMetaData &metaData, const Operator &_op) {
|
||||||
|
memset(&metaData, 0, sizeof(metaData));
|
||||||
|
auto op = as<GatherObj>(_op);
|
||||||
|
auto in = op->getInputs(0);
|
||||||
|
auto index = op->getInputs(1);
|
||||||
|
auto out = op->getOutput();
|
||||||
|
metaData.indexValue = index->getRawDataPtr<int *>();
|
||||||
|
metaData.axis = op->getAxis();
|
||||||
|
metaData.inNDim = in->getDims().size();
|
||||||
|
metaData.outNDim = out->getDims().size();
|
||||||
|
metaData.idxNDim = index->getDims().size();
|
||||||
|
for (int i = 0; i < metaData.outNDim; ++i)
|
||||||
|
metaData.outDim[i] = out->getDims()[i];
|
||||||
|
for (int i = 0; i < metaData.idxNDim; ++i) {
|
||||||
|
metaData.idxDim[i] = index->getDims()[i];
|
||||||
|
metaData.idxStride[i] = index->getStride()[i];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < metaData.inNDim; ++i) {
|
||||||
|
metaData.inStride[i] = in->getStride()[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class GatherCuda : public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
|
||||||
|
auto input = op->getInputs(0);
|
||||||
|
auto index = op->getInputs(1);
|
||||||
|
|
||||||
|
GatherMetaData metaData;
|
||||||
|
initGatherMetaData(metaData, op);
|
||||||
|
|
||||||
|
auto inData = input->getRawDataPtr<float *>();
|
||||||
|
auto outData = op->getOutput()->getRawDataPtr<float *>();
|
||||||
|
gather_kernel(inData, outData, metaData, op->getOutput()->size());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda,
|
||||||
|
"Gather_CUDA_Float32");
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,47 @@
|
||||||
|
#include "cuda/cuda_common.h"
|
||||||
|
#include "cuda/gather.h"
|
||||||
|
|
||||||
|
__device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
|
||||||
|
int offset = 0;
|
||||||
|
for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
|
||||||
|
int idx = 0;
|
||||||
|
if (i == metaData.axis) {
|
||||||
|
int idxOffset = 0;
|
||||||
|
for (int j = metaData.idxNDim - 1; j >= 0; --j) {
|
||||||
|
int p = gOffset % metaData.idxDim[j];
|
||||||
|
gOffset = gOffset / metaData.idxDim[j];
|
||||||
|
idxOffset += p * metaData.idxStride[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
idx = metaData.indexValue[idxOffset];
|
||||||
|
k = k - metaData.idxNDim;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
idx = gOffset % metaData.outDim[k];
|
||||||
|
gOffset = gOffset / metaData.outDim[k];
|
||||||
|
--k;
|
||||||
|
}
|
||||||
|
offset += idx * metaData.inStride[i];
|
||||||
|
}
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void _gather_kernel(float *in, float *out, GatherMetaData metaData,
|
||||||
|
int num) {
|
||||||
|
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
while (tid < num) {
|
||||||
|
int offset = gatheredOffset2Offset(tid, metaData);
|
||||||
|
out[tid] = in[offset];
|
||||||
|
tid += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) {
|
||||||
|
int blockSize = 32 * 16;
|
||||||
|
int gridSize = (num + blockSize - 1) / blockSize;
|
||||||
|
|
||||||
|
_gather_kernel<<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,85 @@
|
||||||
|
#include "operators/gather.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
|
||||||
|
int axis)
|
||||||
|
: OperatorObj(OpType::Gather, {input, index}, {output}), axis(axis) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto dims0 = inputs[0]->getDims();
|
||||||
|
auto dims1 = inputs[1]->getDims();
|
||||||
|
|
||||||
|
if (axis < 0)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
|
||||||
|
if ((size_t)axis >= dims0.size())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
IT_ASSERT(CheckIndexValid());
|
||||||
|
|
||||||
|
Shape dim = dims0;
|
||||||
|
dim.erase(dim.begin() + axis);
|
||||||
|
dim.insert(dim.begin() + axis, dims1.begin(), dims1.end());
|
||||||
|
return {{dim}};
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
||||||
|
IT_ASSERT(inputs.size() == 2);
|
||||||
|
auto index = inputs[1];
|
||||||
|
IT_ASSERT(index->getDType() == DataType::UInt32);
|
||||||
|
return {inputs[0]->getDType()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO:should check everytime index updated.
|
||||||
|
bool GatherObj::CheckIndexValid() const {
|
||||||
|
auto index = inputs[1];
|
||||||
|
if (index->getDataBlob() == nullptr)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
int *data = (int *)runtime->alloc(index->getBytes());
|
||||||
|
index->getRuntime()->copyBlobToCPU(
|
||||||
|
(void *)data, index->getRawDataPtr<void *>(), index->getBytes());
|
||||||
|
|
||||||
|
bool ret = true;
|
||||||
|
auto value = inputs[0]->getDims()[axis];
|
||||||
|
for (size_t i = 0; i < index->size(); ++i) {
|
||||||
|
if (data[i] < 0 || data[i] >= value) {
|
||||||
|
ret = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
runtime->dealloc(data);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GatherObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Gather"
|
||||||
|
<< "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
if (inputs.size() == 2) {
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << vecToString(inputs[1]->getDims()) << ",";
|
||||||
|
}
|
||||||
|
os << "axis=" << axis << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
vector<int> GatherObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret = inputs[0]->getDims();
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
for (auto it : inputs[1]->getDims())
|
||||||
|
ret.emplace_back(it);
|
||||||
|
ret.emplace_back(axis);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> GatherObj::getOpAttrVector() const {
|
||||||
|
return {enum_to_underlying(type), axis};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,244 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "cuda/gather.h"
|
||||||
|
#include "operators/gather.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
namespace infini {
|
||||||
|
/*
|
||||||
|
test1:
|
||||||
|
input = [
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
]
|
||||||
|
indices = [
|
||||||
|
[0, 1],
|
||||||
|
[1, 2],
|
||||||
|
]
|
||||||
|
output = [
|
||||||
|
[
|
||||||
|
[1, 2],
|
||||||
|
[3, 4],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[3, 4],
|
||||||
|
[5, 6],
|
||||||
|
],
|
||||||
|
]
|
||||||
|
axis=0
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
test2
|
||||||
|
input = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
indices = [
|
||||||
|
[0, 2],
|
||||||
|
]
|
||||||
|
axis = 1,
|
||||||
|
output = [
|
||||||
|
[[0, 2]],
|
||||||
|
[[3, 5]],
|
||||||
|
[[6, 8]],
|
||||||
|
]
|
||||||
|
*/
|
||||||
|
/*
|
||||||
|
test3
|
||||||
|
input=[[[ 0, 1],
|
||||||
|
[ 2, 3],
|
||||||
|
[ 4, 5],
|
||||||
|
[ 6, 7]],
|
||||||
|
|
||||||
|
[[ 8, 9],
|
||||||
|
[10, 11],
|
||||||
|
[12, 13],
|
||||||
|
[14, 15]]] //(2,4,2)
|
||||||
|
indices=[[0],[3],[1]] //(3,1)
|
||||||
|
axis=1
|
||||||
|
output=
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
|
||||||
|
int offset = 0;
|
||||||
|
for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
|
||||||
|
int idx = 0;
|
||||||
|
if (i == metaData.axis) {
|
||||||
|
int idxOffset = 0;
|
||||||
|
for (int j = metaData.idxNDim - 1; j >= 0; --j) {
|
||||||
|
int p = gOffset % metaData.idxDim[j];
|
||||||
|
gOffset = gOffset / metaData.idxDim[j];
|
||||||
|
idxOffset += p * metaData.idxStride[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
idx = metaData.indexValue[idxOffset];
|
||||||
|
k = k - metaData.idxNDim;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
idx = gOffset % metaData.outDim[k];
|
||||||
|
gOffset = gOffset / metaData.outDim[k];
|
||||||
|
--k;
|
||||||
|
}
|
||||||
|
offset += idx * metaData.inStride[i];
|
||||||
|
}
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Gather, offsetTrans) {
|
||||||
|
{
|
||||||
|
GatherMetaData meta;
|
||||||
|
int data[] = {0, 1, 1, 2};
|
||||||
|
meta.indexValue = data;
|
||||||
|
meta.axis = 0;
|
||||||
|
meta.inNDim = 2;
|
||||||
|
meta.outNDim = 3;
|
||||||
|
meta.idxNDim = 2;
|
||||||
|
int tmp[] = {2, 2, 2, 0};
|
||||||
|
memcpy(&meta.outDim, &tmp, sizeof(tmp));
|
||||||
|
int tmp2[] = {2, 2, 0, 0};
|
||||||
|
memcpy(&meta.idxDim, &tmp2, sizeof(tmp));
|
||||||
|
int tmp3[] = {2, 1, 0, 0};
|
||||||
|
memcpy(&meta.idxStride, &tmp3, sizeof(tmp));
|
||||||
|
memcpy(&meta.inStride, &tmp3, sizeof(tmp));
|
||||||
|
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(1, meta), 1);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(2, meta), 2);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(3, meta), 3);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(4, meta), 2);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(5, meta), 3);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(6, meta), 4);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(7, meta), 5);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GatherMetaData meta;
|
||||||
|
int data[] = {0, 2};
|
||||||
|
meta.indexValue = data;
|
||||||
|
meta.axis = 1;
|
||||||
|
meta.inNDim = 2;
|
||||||
|
meta.outNDim = 3;
|
||||||
|
meta.idxNDim = 2;
|
||||||
|
|
||||||
|
int tmp[] = {3, 1, 2, 0};
|
||||||
|
memcpy(&meta.outDim, &tmp, sizeof(tmp));
|
||||||
|
int tmp2[] = {1, 2, 0, 0};
|
||||||
|
memcpy(&meta.idxDim, &tmp2, sizeof(tmp2));
|
||||||
|
int tmp3[] = {2, 1, 0, 0};
|
||||||
|
memcpy(&meta.idxStride, &tmp3, sizeof(tmp3));
|
||||||
|
int tmp4[] = {3, 1, 0, 0};
|
||||||
|
memcpy(&meta.inStride, &tmp4, sizeof(tmp4));
|
||||||
|
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(1, meta), 2);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(2, meta), 3);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(3, meta), 5);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(4, meta), 6);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(5, meta), 8);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GatherMetaData meta;
|
||||||
|
int data[] = {0, 3, 1};
|
||||||
|
meta.indexValue = data;
|
||||||
|
meta.axis = 1;
|
||||||
|
meta.inNDim = 3;
|
||||||
|
meta.outNDim = 4;
|
||||||
|
meta.idxNDim = 2;
|
||||||
|
|
||||||
|
int tmp[] = {2, 3, 1, 2};
|
||||||
|
memcpy(&meta.outDim, &tmp, sizeof(tmp));
|
||||||
|
int tmp2[] = {3, 1, 0, 0};
|
||||||
|
memcpy(&meta.idxDim, &tmp2, sizeof(tmp2));
|
||||||
|
int tmp3[] = {1, 1, 0, 0};
|
||||||
|
memcpy(&meta.idxStride, &tmp3, sizeof(tmp3));
|
||||||
|
int tmp4[] = {8, 2, 1, 0};
|
||||||
|
memcpy(&meta.inStride, &tmp4, sizeof(tmp4));
|
||||||
|
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(0, meta), 0);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(1, meta), 1);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(2, meta), 6);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(3, meta), 7);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(4, meta), 2);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(5, meta), 3);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(6, meta), 8);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(7, meta), 9);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(8, meta), 14);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(9, meta), 15);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(10, meta), 10);
|
||||||
|
EXPECT_EQ(gatheredOffset2Offset(11, meta), 11);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Gather, Cuda) {
|
||||||
|
{
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
auto input = gCpu->addTensor({3, 2}, DataType::Float32);
|
||||||
|
auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
input->copyData(vector<float>{1, 2, 3, 4, 5, 6});
|
||||||
|
index->copyData(vector<uint32_t>{0, 1, 1, 2});
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto op = gCuda->addOp<GatherObj>(
|
||||||
|
gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 0);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
// copy output from CUDA to CPU
|
||||||
|
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||||
|
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
auto input = gCpu->addTensor({3, 3}, DataType::Float32);
|
||||||
|
auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
input->setData(IncrementalGenerator());
|
||||||
|
index->copyData(vector<uint32_t>{0, 2});
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto op = gCuda->addOp<GatherObj>(
|
||||||
|
gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 1);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
// copy output from CUDA to CPU
|
||||||
|
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||||
|
EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
|
||||||
|
auto index = gCpu->addTensor({3, 1}, DataType::UInt32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
input->setData(IncrementalGenerator());
|
||||||
|
index->copyData(vector<uint32_t>{0, 3, 1});
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto op = gCuda->addOp<GatherObj>(
|
||||||
|
gCuda->cloneTensor(input), gCuda->cloneTensor(index), nullptr, 1);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
// copy output from CUDA to CPU
|
||||||
|
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||||
|
EXPECT_TRUE(oCpu->equalData(
|
||||||
|
vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,19 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/gather.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(Gather, ShapeInference) {
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
|
||||||
|
Tensor index = g->addTensor({2, 1, 2}, DataType::UInt32);
|
||||||
|
auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue