forked from jiuyuan/InfiniTensor
Add GatherElements op and cuda kernel (#149)
* Add GatherElements op and cuda kernel
* fix format
* remove print
* remove unused var
* fix spacing
* fix format

---------

Co-authored-by: panzezhong@qiyuanlab.com <panzezhong@zezhongpan>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent ed3034f878
commit 36ae7b7fb6
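For reference, the new operator follows ONNX GatherElements (torch.gather-style) semantics: the output has the same shape as indices, and each output element reads data at the stored index along axis. A minimal pure-Python sketch of the rank-2 case, not part of the commit, only for orientation (the asserted values match the floatDataIntIndices CUDA test later in this diff):

def gather_elements_2d(data, indices, axis):
    # out has the shape of `indices`; along `axis` the coordinate is replaced
    # by the value stored in `indices`, all other coordinates are kept.
    rows, cols = len(indices), len(indices[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            k = indices[i][j]
            out[i][j] = data[k][j] if axis == 0 else data[i][k]
    return out

assert gather_elements_2d([[1, 2], [3, 4]], [[0, 0], [1, 0]], axis=1) == [[1, 1], [4, 3]]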
@@ -67,6 +67,7 @@ class GraphHandlerObj {
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     int num_outputs);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
+    Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor reduceMean(Tensor data, Tensor reduced,
                       const optional<vector<int>> &axes, bool keepdims);
     Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
@@ -1,19 +1,60 @@
 #pragma once
 #include "core/data_type.h"
+#include "core/operator.h"
+#include "operators/gather.h"
 
 namespace infini {
 struct GatherMetaData {
+    // Pointer to indices
     void *indexValue;
+    // Type of index values
     DataType indexType;
+    // Type of input and output data
+    DataType dataType;
+    // Axis of the gather operation
     int axis;
+    // Rank of input
     int inNDim;
+    // Rank of output
     int outNDim;
+    // Rank of indices
     int idxNDim;
+    // Shape of output
     int outDim[4];
+    // Shape of indices
     int idxDim[4];
+    // Strides of indices
     int idxStride[4];
+    // Strides of input
     int inStride[4];
 };
 
+inline void initGatherMetaData(GatherMetaData &metaData,
+                               const Ref<OperatorObj> &_op) {
+    memset(&metaData, 0, sizeof(metaData));
+    auto op = as<GatherBaseObj>(_op);
+    Ref<TensorObj> in = op->getInputs(0);
+    Ref<TensorObj> index = op->getInputs(1);
+    Ref<TensorObj> out = op->getOutput();
+    metaData.indexValue = index->getRawDataPtr<void *>();
+    metaData.indexType = index->getDType();
+    metaData.dataType = in->getDType();
+    metaData.axis = op->getAxis();
+    metaData.inNDim = in->getRank();
+    metaData.outNDim = out->getRank();
+    metaData.idxNDim = index->getRank();
+    for (int i = 0; i < metaData.outNDim; ++i)
+        metaData.outDim[i] = out->getDims()[i];
+    for (int i = 0; i < metaData.idxNDim; ++i) {
+        metaData.idxDim[i] = index->getDims()[i];
+        metaData.idxStride[i] = index->getStride()[i];
+    }
+    for (int i = 0; i < metaData.inNDim; ++i) {
+        metaData.inStride[i] = in->getStride()[i];
+    }
+}
 void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
+
+void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
+                            size_t num);
 } // namespace infini
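GatherMetaData flattens everything the device code needs into plain arrays; the fixed-size [4] arrays cap the supported rank at 4. The stride fields are assumed here to be ordinary row-major element strides (which is what the offset arithmetic in the .cu kernel below relies on). A small sketch of how such strides relate to a shape, for intuition only:

def row_major_strides(dims):
    # e.g. dims [2, 3, 4] -> strides [12, 4, 1]; offset = sum(i_k * stride[k])
    strides = [1] * len(dims)
    for k in reversed(range(len(dims) - 1)):
        strides[k] = strides[k + 1] * dims[k + 1]
    return strides

assert row_major_strides([2, 3, 4]) == [12, 4, 1]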
@@ -3,14 +3,28 @@
 #include "core/operator.h"
 
 namespace infini {
+class GatherBaseObj : public OperatorObj {
+  protected:
+    int axis;
+
+  public:
+    GatherBaseObj(OpType opType, TensorVec inputs, TensorVec outputs, int axis)
+        : OperatorObj(opType, inputs, outputs), axis(axis) {}
+
+    virtual ~GatherBaseObj() {}
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+
+    int getAxis() const { return axis; }
+};
+
 /**
  * @brief Gather and concatenate given positions on a certain dimension of the
  * input tensor using an index tensor.
  *
  */
-class GatherObj : public OperatorObj {
-    int axis;
+class GatherObj : public GatherBaseObj {
 
   public:
     /**
      * @brief Construct a new Gather object.
@@ -25,10 +39,7 @@ class GatherObj : public OperatorObj {
               int axis);
     OP_CLONE(GatherObj);
     std::string toString() const override;
-    int numInputs() const override { return 2; }
-    int numOutputs() const override { return 1; }
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-    int getAxis() const { return axis; }
     vector<DataType> inferDataType(const TensorVec &inputs) const override;
 
   private:
@@ -36,4 +47,33 @@ class GatherObj : public OperatorObj {
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
+
+/**
+ * @brief GatherElements takes two inputs data and indices of the
+ * same rank r >= 1 and an optional attribute axis that identifies
+ * an axis of data.
+ *
+ */
+class GatherElementsObj : public GatherBaseObj {
+  public:
+    /**
+     * @brief Construct a new GatherElements object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor.
+     * @param indices The index tensor.
+     * @param output The output tensor. Same shape as indices.
+     * @param axis The axis to gather on.
+     */
+    GatherElementsObj(GraphObj *graph, Tensor input, Tensor indices,
+                      Tensor output, int axis);
+    OP_CLONE(GatherElementsObj);
+    std::string toString() const override;
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
 } // namespace infini
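A shape-only contrast shows why GatherElements is a separate operator rather than a flag on GatherObj: from the same kind of inputs the two ops infer different output shapes. The numbers below are taken from the Python frontend tests later in this diff; the helper functions are illustrative, not framework code:

def gather_out_shape(data_shape, indices_shape, axis):
    # ONNX Gather: the indices shape is spliced in place of the gathered axis.
    return data_shape[:axis] + indices_shape + data_shape[axis + 1:]

def gather_elements_out_shape(data_shape, indices_shape, axis):
    # ONNX GatherElements: output shape is exactly the indices shape.
    # (data_shape and axis are unused; kept only for symmetry.)
    return indices_shape

assert gather_out_shape([1, 3, 4, 4], [2, 1, 2], 1) == [1, 2, 1, 2, 4, 4]
assert gather_elements_out_shape([2, 3, 2], [2, 1, 2], 1) == [2, 1, 2]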
@@ -562,6 +562,16 @@ class OnnxStub:
                         0,
                     ),
                 )
+            elif node.op_type == "GatherElements":
+                tensors[node.output[0]] = self.handler.gatherElements(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                    next(
+                        (attr.i for attr in node.attribute if attr.name == "axis"),
+                        0,
+                    ),
+                )
             elif node.op_type == "ReduceMean":
                 tensors[node.output[0]] = self.handler.reduce_mean(
                     tensors[node.input[0]],
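The next(..., 0) idiom above implements ONNX's optional axis attribute, which defaults to 0 when absent. A standalone illustration of the pattern with hypothetical stand-ins for the attribute protos:

from types import SimpleNamespace

# Hypothetical stand-ins for onnx AttributeProto entries.
attrs_with_axis = [SimpleNamespace(name="axis", i=1)]
attrs_without_axis = []

get_axis = lambda attrs: next((a.i for a in attrs if a.name == "axis"), 0)
assert get_axis(attrs_with_axis) == 1
assert get_axis(attrs_without_axis) == 0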
@@ -307,13 +307,22 @@ class TestStringMethods(unittest.TestCase):
 
     def test_gather(self):
         data = make_tensor_value_info("data", TensorProto.FLOAT, [1, 3, 4, 4])
-        indices = make_tensor_value_info("indices", TensorProto.FLOAT, [2, 1, 2])
+        indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 2, 4, 4])
         gather = make_node(
             "Gather", ["data", "indices"], ["output"], axis=1, name="gather"
         )
         make_and_import_model(make_graph([gather], "gather", [data, indices], [output]))
 
+    def test_gather_elements(self):
+        data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 2])
+        indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2])
+        output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 1, 2])
+        gatherElements = make_node(
+            "GatherElements", ["data", "indices"], ["output"], axis=1, name="gatherElements"
+        )
+        make_and_import_model(make_graph([gatherElements], "gatherElements", [data, indices], [output]))
+
     def test_reduce_mean(self):
         data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
         reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
@@ -265,6 +265,20 @@ Tensor GraphHandlerObj::gather(Tensor data, Tensor indices, Tensor output,
     }
 }
 
+Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices,
+                                       Tensor output, int axis) {
+    if (output) {
+        g->addOpWithOutputs<GatherElementsObj>(
+            std::move(data), std::move(indices), output, axis);
+        return output;
+    } else {
+        return g
+            ->addOp<GatherElementsObj>(std::move(data), std::move(indices),
+                                       output, axis)
+            ->getOutput();
+    }
+}
+
 Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
                                    const optional<vector<int>> &axes,
                                    bool keepdims) {
@@ -85,6 +85,7 @@ void export_values(py::module &m) {
         .VALUE(OpType, Div)
         .VALUE(OpType, Pow)
         .VALUE(OpType, Gather)
+        .VALUE(OpType, GatherElements)
         .VALUE(OpType, ReduceMean)
         .VALUE(OpType, Reshape)
         .VALUE(OpType, Flatten)
@@ -227,8 +228,9 @@ static int split_axis_of(Operator op) {
 }
 
 static int gather_axis_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::Gather);
-    return dynamic_cast<const GatherObj *>(op.get())->getAxis();
+    IT_ASSERT(op->getOpType() == OpType::Gather ||
+              op->getOpType() == OpType::GatherElements);
+    return dynamic_cast<const GatherBaseObj *>(op.get())->getAxis();
 }
 
 static vector<int64_t> reshape_shape_of(Operator op) {
@@ -462,6 +464,7 @@ void init_graph_builder(py::module &m) {
         .def("concat", &Handler::concat, policy::move)
        .def("split", &Handler::split, policy::move)
        .def("gather", &Handler::gather, policy::move)
+        .def("gatherElements", &Handler::gatherElements, policy::move)
        .def("reduce_mean", &Handler::reduceMean, policy::move)
        .def("slice", &Handler::slice, policy::move)
        .def("pad", &Handler::pad, policy::move)
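With the .def("gatherElements", ...) binding in place, the Python frontend calls the handler exactly as the onnx.py hunk above shows. A toy stand-in for the compiled handler, only to make the call shape explicit; FakeHandler is hypothetical, the real object comes from the pybind11 backend module:

# Hypothetical mock of the bound GraphHandler, illustrating the signature
# gatherElements(data, indices, output, axis) established above.
class FakeHandler:
    def gatherElements(self, data, indices, output, axis):
        # Passing None for output (as tensors.get(...) does when the output
        # tensor is not yet known) lets the handler create and return it.
        return output if output is not None else f"new tensor (axis={axis})"

handler = FakeHandler()
assert handler.gatherElements("data", "indices", None, 1) == "new tensor (axis=1)"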
@@ -5,29 +5,6 @@
 
 namespace infini {
 class GatherCuda : public CudaKernelWithoutConfig {
-    void initGatherMetaData(GatherMetaData &metaData,
-                            const Operator &_op) const {
-        memset(&metaData, 0, sizeof(metaData));
-        auto op = as<GatherObj>(_op);
-        auto in = op->getInputs(0);
-        auto index = op->getInputs(1);
-        auto out = op->getOutput();
-        metaData.indexValue = index->getRawDataPtr<void *>();
-        metaData.indexType = index->getDType();
-        metaData.axis = op->getAxis();
-        metaData.inNDim = in->getRank();
-        metaData.outNDim = out->getRank();
-        metaData.idxNDim = index->getRank();
-        for (int i = 0; i < metaData.outNDim; ++i)
-            metaData.outDim[i] = out->getDims()[i];
-        for (int i = 0; i < metaData.idxNDim; ++i) {
-            metaData.idxDim[i] = index->getDims()[i];
-            metaData.idxStride[i] = index->getStride()[i];
-        }
-        for (int i = 0; i < metaData.inNDim; ++i) {
-            metaData.inStride[i] = in->getStride()[i];
-        }
-    }
 
     void compute(const Operator &op,
                  const RuntimeObj *_context) const override {
@@ -0,0 +1,28 @@
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/gather.h"
+#include "operators/gather.h"
+
+namespace infini {
+
+class GatherElementsCuda : public CudaKernelWithoutConfig {
+
+    void compute(const Operator &op,
+                 const RuntimeObj *_context) const override {
+        GatherMetaData metaData;
+        initGatherMetaData(metaData, op);
+
+        auto input = op->getInputs(0);
+        auto output = op->getOutput();
+        void *inData = input->getRawDataPtr<void *>();
+        void *outData = output->getRawDataPtr<void *>();
+        gather_elements_kernel(inData, outData, metaData,
+                               op->getOutput()->size());
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32,
+                GatherElementsCuda, "GatherELements_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32,
+                GatherElementsCuda, "GatherElements_CUDA_Int32");
+} // namespace infini
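REGISTER_KERNEL keys a kernel on (device, op type, data type), which is why Float32 and Int32 each get their own registration line even though both resolve to the same GatherElementsCuda class. A toy sketch of that dispatch idea, not the framework's actual registry code:

kernel_registry = {}

def register_kernel(device, op_type, dtype, kernel, name):
    # One entry per (device, op, dtype) triple, mirroring the macro's key.
    kernel_registry[(device, op_type, dtype)] = (kernel, name)

register_kernel("CUDA", "GatherElements", "Float32", "GatherElementsCuda",
                "GatherELements_CUDA_Float32")
register_kernel("CUDA", "GatherElements", "Int32", "GatherElementsCuda",
                "GatherElements_CUDA_Int32")
assert kernel_registry[("CUDA", "GatherElements", "Int32")][0] == "GatherElementsCuda"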
@@ -0,0 +1,65 @@
+#include "cuda/cuda_common.h"
+#include "cuda/gather.h"
+
+template <typename Tind>
+__device__ Tind tid2Offset(Tind tid, infini::GatherMetaData metaData) {
+    Tind offset = 0;
+    Tind gOffset = tid;
+    for (int i = metaData.inNDim - 1; i >= 0; --i) {
+        if (i == metaData.axis) {
+            Tind idx = static_cast<Tind *>(metaData.indexValue)[tid];
+            offset += idx * metaData.inStride[i];
+        } else {
+            Tind p = gOffset % metaData.idxDim[i];
+            offset += p * metaData.inStride[i];
+        }
+
+        gOffset = gOffset / metaData.idxDim[i];
+    }
+
+    return offset;
+}
+
+template <typename T, typename Tind>
+__global__ void _gather_elements_kernel(T *in, T *out,
+                                        infini::GatherMetaData metaData,
+                                        size_t num) {
+    Tind tid = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    while (tid < num) {
+        Tind offset = tid2Offset<Tind>(tid, metaData);
+        out[tid] = in[offset];
+        tid += stride;
+    }
+}
+
+namespace infini {
+void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
+                            size_t num) {
+    int blockSize = 1024;
+    int gridSize = (num + blockSize - 1) / blockSize;
+    if (metaData.dataType == DataType::Float32 &&
+        metaData.indexType == DataType::Int64) {
+        _gather_elements_kernel<float, int64_t><<<gridSize, blockSize>>>(
+            reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
+            metaData, num);
+    } else if (metaData.dataType == DataType::Int32 &&
+               metaData.indexType == DataType::Int64) {
+        _gather_elements_kernel<int, int64_t><<<gridSize, blockSize>>>(
+            reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
+            num);
+    } else if (metaData.dataType == DataType::Float32 &&
+               metaData.indexType == DataType::Int32) {
+        _gather_elements_kernel<float, int><<<gridSize, blockSize>>>(
+            reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
+            metaData, num);
+    } else if (metaData.dataType == DataType::Int32 &&
+               metaData.indexType == DataType::Int32) {
+        _gather_elements_kernel<int, int><<<gridSize, blockSize>>>(
+            reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
+            num);
+    } else {
+        IT_TODO_HALT_MSG("GatherElements Cuda Kernel: Unsupported data type.\n");
+    }
+}
+} // namespace infini
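The launch uses a grid-stride loop (one logical thread per output element), and tid2Offset decodes the flat output index into indices-space coordinates via idxDim, then multiplies each coordinate by the corresponding input stride, except along axis, where the stored index value is used instead. A Python transcription for checking the arithmetic against the first CUDA test below, assuming row-major strides as sketched earlier:

def tid2offset(tid, idx_dim, in_stride, axis, index_values):
    # Decode the flat output index `tid` dimension by dimension (ranks of
    # data and indices are equal, so idx_dim drives the decoding).
    offset, g = 0, tid
    for i in reversed(range(len(idx_dim))):
        if i == axis:
            offset += index_values[tid] * in_stride[i]
        else:
            offset += (g % idx_dim[i]) * in_stride[i]
        g //= idx_dim[i]
    return offset

# Data from the intDataLongIndices test: data {3,3} Int32, indices {2,3}, axis=0.
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]   # row-major 3x3
index = [1, 2, 0, 2, 0, 0]           # row-major 2x3
out = [data[tid2offset(t, [2, 3], [3, 1], 0, index)] for t in range(6)]
assert out == [4, 8, 3, 7, 2, 3]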
@@ -4,7 +4,7 @@
 namespace infini {
 GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
                      Tensor output, int axis)
-    : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
+    : GatherBaseObj(OpType::Gather, {input, indices}, {output}, axis) {
     int rank = input->getRank();
     this->axis = get_real_axis(axis, rank);
     IT_ASSERT(checkValid(graph));
@@ -0,0 +1,70 @@
+#include "operators/gather.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+GatherElementsObj::GatherElementsObj(GraphObj *graph, Tensor input,
+                                     Tensor indices, Tensor output, int axis)
+    : GatherBaseObj(OpType::GatherElements, {input, indices}, {output}, axis) {
+    int rank = input->getRank();
+    this->axis = get_real_axis(axis, rank);
+    IT_ASSERT(checkValid(graph));
+}
+
+bool checkShape(Tensor input, Tensor indices, int axis) {
+    auto inputDims = input->getDims();
+    auto indicesDims = indices->getDims();
+    if (input->getRank() != indices->getRank()) {
+        return false;
+    }
+    for (int i = 0; i < static_cast<int>(input->getRank()); ++i) {
+        if (i != axis && inputDims[i] != indicesDims[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+optional<vector<Shape>>
+GatherElementsObj::inferShape(const TensorVec &inputs) const {
+    IT_ASSERT(checkShape(inputs[0], inputs[1], axis));
+    auto indicesDims = inputs[1]->getDims(); // output has same shape as indices
+    return {{indicesDims}};
+}
+
+vector<DataType>
+GatherElementsObj::inferDataType(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 2);
+    auto indexDtype = inputs[1]->getDType();
+    IT_ASSERT(indexDtype == DataType::Int32 || indexDtype == DataType::Int64);
+    return {inputs[0]->getDType()};
+}
+
+std::string GatherElementsObj::toString() const {
+    std::ostringstream os;
+    os << "GatherElements"
+       << "[" << getGuid() << "]";
+    os << "(";
+    if (inputs.size() == 2) {
+        os << vecToString(inputs[0]->getDims()) << ",";
+        os << vecToString(inputs[1]->getDims()) << ",";
+    }
+    os << "axis=" << axis << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> GatherElementsObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), type.underlying());
+    for (auto it : inputs[1]->getDims())
+        ret.emplace_back(it);
+    ret.emplace_back(axis);
+    return ret;
+}
+
+vector<int> GatherElementsObj::getOpAttrVector() const {
+    return {type.underlying(), axis};
+}
+
+} // namespace infini
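checkShape encodes the GatherElements contract, equal ranks and equal extents on every axis except the gathered one, and inferShape then simply returns the indices shape. The same rule in a few lines of Python, for clarity only, with shapes borrowed from the operator-level test below:

def check_shape(data_shape, indices_shape, axis):
    # Ranks must match; every dimension except `axis` must match too.
    if len(data_shape) != len(indices_shape):
        return False
    return all(d == i for k, (d, i) in enumerate(zip(data_shape, indices_shape))
               if k != axis)

assert check_shape([2, 4, 2], [2, 1, 2], axis=1)        # accepted by inferShape
assert not check_shape([2, 4, 2], [2, 1, 3], axis=1)    # trailing dim differs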
@@ -0,0 +1,43 @@
+#include "core/graph.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "cuda/gather.h"
+#include "operators/gather.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(GatherElements, intDataLongIndices) {
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputCuda = gCuda->addTensor({3, 3}, DataType::Int32);
+    auto indexCuda = gCuda->addTensor({2, 3}, DataType::Int64);
+    auto op = gCuda->addOp<GatherElementsObj>(inputCuda, indexCuda, nullptr, 0);
+    gCuda->dataMalloc();
+    inputCuda->copyin(vector<int>{1, 2, 3, 4, 5, 6, 7, 8, 9});
+    indexCuda->copyin(vector<int64_t>{1, 2, 0, 2, 0, 0});
+
+    cudaRuntime->run(gCuda);
+    auto result = op->getOutput()->clone(cpuRuntime);
+    EXPECT_TRUE(result->equalData<int>({4, 8, 3, 7, 2, 3}));
+}
+
+TEST(GatherElements, floatDataIntIndices) {
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputCuda = gCuda->addTensor({2, 2}, DataType::Float32);
+    auto indexCuda = gCuda->addTensor({2, 2}, DataType::Int32);
+    auto op = gCuda->addOp<GatherElementsObj>(inputCuda, indexCuda, nullptr, 1);
+    gCuda->dataMalloc();
+    inputCuda->copyin(vector<float>{1., 2., 3., 4.});
+    indexCuda->copyin(vector<int>{0, 0, 1, 0});
+
+    cudaRuntime->run(gCuda);
+    auto result = op->getOutput()->clone(cpuRuntime);
+    EXPECT_TRUE(result->equalData<float>({1., 1., 4., 3.}));
+}
+} // namespace infini
@@ -0,0 +1,29 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/gather.h"
+
+#include "test.h"
+
+namespace infini {
+
+TEST(Gather, ShapeTypeInference) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({3, 3, 3}, DataType::Int32);
+        Tensor index = g->addTensor({2, 3, 3}, DataType::Int32);
+        auto op = g->addOp<GatherElementsObj>(i, index, nullptr, 0);
+        EXPECT_EQ(op->getOutput()->getDType(), DataType::Int32);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({2, 4, 2}, DataType::Float32);
+        Tensor index = g->addTensor({2, 1, 2}, DataType::Int64);
+        auto op = g->addOp<GatherElementsObj>(i, index, nullptr, 1);
+        EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 2}));
+    }
+}
+} // namespace infini