diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 6c670227..87e909f8 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -67,6 +67,7 @@ class GraphHandlerObj { TensorVec split(Tensor input, std::optional outputs, int axis, int num_outputs); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); + Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis); Tensor reduceMean(Tensor data, Tensor reduced, const optional> &axes, bool keepdims); Tensor slice(Tensor input, Tensor output, const vector &starts, diff --git a/include/cuda/gather.h b/include/cuda/gather.h index f3e0956a..0f0a1b27 100644 --- a/include/cuda/gather.h +++ b/include/cuda/gather.h @@ -1,19 +1,60 @@ #pragma once #include "core/data_type.h" +#include "core/operator.h" +#include "operators/gather.h" namespace infini { struct GatherMetaData { + // Pointer to indices void *indexValue; + // Type of index values DataType indexType; + // Type of input and output data + DataType dataType; + // Axis of the gather operation int axis; + // Rank of input int inNDim; + // Rank of output int outNDim; + // Rank of indices int idxNDim; + // Shape of output int outDim[4]; + // Shape of indices int idxDim[4]; + // Strides of indices int idxStride[4]; + // Strides of input int inStride[4]; }; +inline void initGatherMetaData(GatherMetaData &metaData, + const Ref &_op) { + memset(&metaData, 0, sizeof(metaData)); + auto op = as(_op); + Ref in = op->getInputs(0); + Ref index = op->getInputs(1); + Ref out = op->getOutput(); + metaData.indexValue = index->getRawDataPtr(); + metaData.indexType = index->getDType(); + metaData.dataType = in->getDType(); + metaData.axis = op->getAxis(); + metaData.inNDim = in->getRank(); + metaData.outNDim = out->getRank(); + metaData.idxNDim = index->getRank(); + for (int i = 0; i < metaData.outNDim; ++i) + metaData.outDim[i] = out->getDims()[i]; + for (int i = 0; i < metaData.idxNDim; ++i) { + metaData.idxDim[i] = index->getDims()[i]; + metaData.idxStride[i] = index->getStride()[i]; + } + for (int i = 0; i < metaData.inNDim; ++i) { + metaData.inStride[i] = in->getStride()[i]; + } +} void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num); + +void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, + size_t num); } // namespace infini diff --git a/include/operators/gather.h b/include/operators/gather.h index d5d07a69..ff35aba8 100644 --- a/include/operators/gather.h +++ b/include/operators/gather.h @@ -3,14 +3,28 @@ #include "core/operator.h" namespace infini { + +class GatherBaseObj : public OperatorObj { + protected: + int axis; + + public: + GatherBaseObj(OpType opType, TensorVec inputs, TensorVec outputs, int axis) + : OperatorObj(opType, inputs, outputs), axis(axis) {} + + virtual ~GatherBaseObj() {} + int numInputs() const override { return 2; } + int numOutputs() const override { return 1; } + + int getAxis() const { return axis; } +}; + /** * @brief Gather and concatenate given positions on a certain dimension of the * input tensor using an index tensor. * */ -class GatherObj : public OperatorObj { - int axis; - +class GatherObj : public GatherBaseObj { public: /** * @brief Construct a new Gather object. @@ -25,10 +39,7 @@ class GatherObj : public OperatorObj { int axis); OP_CLONE(GatherObj); std::string toString() const override; - int numInputs() const override { return 2; } - int numOutputs() const override { return 1; } optional> inferShape(const TensorVec &inputs) const override; - int getAxis() const { return axis; } vector inferDataType(const TensorVec &inputs) const override; private: @@ -36,4 +47,33 @@ class GatherObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; + +/** + * @brief GatherElements takes two inputs data and indices of the + * same rank r >= 1 and an optional attribute axis that identifies + * an axis of data. + * + */ +class GatherElementsObj : public GatherBaseObj { + public: + /** + * @brief Construct a new GatherElements object. + * + * @param graph The computation graph that this operator belongs to. + * @param input The input tensor. + * @param indices The index tensor. + * @param output The output tensor. Same shape as indices. + * @param axis The axis to gather on. + */ + GatherElementsObj(GraphObj *graph, Tensor input, Tensor indices, + Tensor output, int axis); + OP_CLONE(GatherElementsObj); + std::string toString() const override; + optional> inferShape(const TensorVec &inputs) const override; + vector inferDataType(const TensorVec &inputs) const override; + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 659a9802..e4336dc4 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -562,6 +562,16 @@ class OnnxStub: 0, ), ) + elif node.op_type == "GatherElements": + tensors[node.output[0]] = self.handler.gatherElements( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ), + ) elif node.op_type == "ReduceMean": tensors[node.output[0]] = self.handler.reduce_mean( tensors[node.input[0]], diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index f80ad220..3808f516 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -307,13 +307,22 @@ class TestStringMethods(unittest.TestCase): def test_gather(self): data = make_tensor_value_info("data", TensorProto.FLOAT, [1, 3, 4, 4]) - indices = make_tensor_value_info("indices", TensorProto.FLOAT, [2, 1, 2]) + indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2]) output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 2, 4, 4]) gather = make_node( "Gather", ["data", "indices"], ["output"], axis=1, name="gather" ) make_and_import_model(make_graph([gather], "gather", [data, indices], [output])) + def test_gather_elements(self): + data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 2]) + indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 1, 2]) + gatherElements = make_node( + "GatherElements", ["data", "indices"], ["output"], axis=1, name="gatherElements" + ) + make_and_import_model(make_graph([gatherElements], "gatherElements", [data, indices], [output])) + def test_reduce_mean(self): data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4]) reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 225fae09..77fbcf2d 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -265,6 +265,20 @@ Tensor GraphHandlerObj::gather(Tensor data, Tensor indices, Tensor output, } } +Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices, + Tensor output, int axis) { + if (output) { + g->addOpWithOutputs( + std::move(data), std::move(indices), output, axis); + return output; + } else { + return g + ->addOp(std::move(data), std::move(indices), + output, axis) + ->getOutput(); + } +} + Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced, const optional> &axes, bool keepdims) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 8ac563b6..9881f92a 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -85,6 +85,7 @@ void export_values(py::module &m) { .VALUE(OpType, Div) .VALUE(OpType, Pow) .VALUE(OpType, Gather) + .VALUE(OpType, GatherElements) .VALUE(OpType, ReduceMean) .VALUE(OpType, Reshape) .VALUE(OpType, Flatten) @@ -227,8 +228,9 @@ static int split_axis_of(Operator op) { } static int gather_axis_of(Operator op) { - IT_ASSERT(op->getOpType() == OpType::Gather); - return dynamic_cast(op.get())->getAxis(); + IT_ASSERT(op->getOpType() == OpType::Gather || + op->getOpType() == OpType::GatherElements); + return dynamic_cast(op.get())->getAxis(); } static vector reshape_shape_of(Operator op) { @@ -462,6 +464,7 @@ void init_graph_builder(py::module &m) { .def("concat", &Handler::concat, policy::move) .def("split", &Handler::split, policy::move) .def("gather", &Handler::gather, policy::move) + .def("gatherElements", &Handler::gatherElements, policy::move) .def("reduce_mean", &Handler::reduceMean, policy::move) .def("slice", &Handler::slice, policy::move) .def("pad", &Handler::pad, policy::move) diff --git a/src/kernels/cuda/gather.cc b/src/kernels/cuda/gather.cc index e438db99..54e6bd10 100644 --- a/src/kernels/cuda/gather.cc +++ b/src/kernels/cuda/gather.cc @@ -5,29 +5,6 @@ namespace infini { class GatherCuda : public CudaKernelWithoutConfig { - void initGatherMetaData(GatherMetaData &metaData, - const Operator &_op) const { - memset(&metaData, 0, sizeof(metaData)); - auto op = as(_op); - auto in = op->getInputs(0); - auto index = op->getInputs(1); - auto out = op->getOutput(); - metaData.indexValue = index->getRawDataPtr(); - metaData.indexType = index->getDType(); - metaData.axis = op->getAxis(); - metaData.inNDim = in->getRank(); - metaData.outNDim = out->getRank(); - metaData.idxNDim = index->getRank(); - for (int i = 0; i < metaData.outNDim; ++i) - metaData.outDim[i] = out->getDims()[i]; - for (int i = 0; i < metaData.idxNDim; ++i) { - metaData.idxDim[i] = index->getDims()[i]; - metaData.idxStride[i] = index->getStride()[i]; - } - for (int i = 0; i < metaData.inNDim; ++i) { - metaData.inStride[i] = in->getStride()[i]; - } - } void compute(const Operator &op, const RuntimeObj *_context) const override { diff --git a/src/kernels/cuda/gather_elements.cc b/src/kernels/cuda/gather_elements.cc new file mode 100644 index 00000000..795a5c6f --- /dev/null +++ b/src/kernels/cuda/gather_elements.cc @@ -0,0 +1,28 @@ +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/gather.h" +#include "operators/gather.h" + +namespace infini { + +class GatherElementsCuda : public CudaKernelWithoutConfig { + + void compute(const Operator &op, + const RuntimeObj *_context) const override { + GatherMetaData metaData; + initGatherMetaData(metaData, op); + + auto input = op->getInputs(0); + auto output = op->getOutput(); + void *inData = input->getRawDataPtr(); + void *outData = output->getRawDataPtr(); + gather_elements_kernel(inData, outData, metaData, + op->getOutput()->size()); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32, + GatherElementsCuda, "GatherELements_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32, + GatherElementsCuda, "GatherElements_CUDA_Int32"); +} // namespace infini diff --git a/src/kernels/cuda/gather_elements.cu b/src/kernels/cuda/gather_elements.cu new file mode 100644 index 00000000..675a6b15 --- /dev/null +++ b/src/kernels/cuda/gather_elements.cu @@ -0,0 +1,65 @@ +#include "cuda/cuda_common.h" +#include "cuda/gather.h" + +template +__device__ Tind tid2Offset(Tind tid, infini::GatherMetaData metaData) { + Tind offset = 0; + Tind gOffset = tid; + for (int i = metaData.inNDim - 1; i >= 0; --i) { + if (i == metaData.axis) { + Tind idx = static_cast(metaData.indexValue)[tid]; + offset += idx * metaData.inStride[i]; + } else { + Tind p = gOffset % metaData.idxDim[i]; + offset += p * metaData.inStride[i]; + } + + gOffset = gOffset / metaData.idxDim[i]; + } + + return offset; +} + +template +__global__ void _gather_elements_kernel(T *in, T *out, + infini::GatherMetaData metaData, + size_t num) { + Tind tid = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + while (tid < num) { + Tind offset = tid2Offset(tid, metaData); + out[tid] = in[offset]; + tid += stride; + } +} + +namespace infini { +void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, + size_t num) { + int blockSize = 1024; + int gridSize = (num + blockSize - 1) / blockSize; + if (metaData.dataType == DataType::Float32 && + metaData.indexType == DataType::Int64) { + _gather_elements_kernel<<>>( + reinterpret_cast(in), reinterpret_cast(out), + metaData, num); + } else if (metaData.dataType == DataType::Int32 && + metaData.indexType == DataType::Int64) { + _gather_elements_kernel<<>>( + reinterpret_cast(in), reinterpret_cast(out), metaData, + num); + } else if (metaData.dataType == DataType::Float32 && + metaData.indexType == DataType::Int32) { + _gather_elements_kernel<<>>( + reinterpret_cast(in), reinterpret_cast(out), + metaData, num); + } else if (metaData.dataType == DataType::Int32 && + metaData.indexType == DataType::Int32) { + _gather_elements_kernel<<>>( + reinterpret_cast(in), reinterpret_cast(out), metaData, + num); + } else { + IT_TODO_HALT_MSG("GatherElements Cuda Kernel: Unsupported data type.\n"); + } +} +} // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 96493323..0cddca3c 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -4,7 +4,7 @@ namespace infini { GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices, Tensor output, int axis) - : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) { + : GatherBaseObj(OpType::Gather, {input, indices}, {output}, axis) { int rank = input->getRank(); this->axis = get_real_axis(axis, rank); IT_ASSERT(checkValid(graph)); diff --git a/src/operators/gather_elements.cc b/src/operators/gather_elements.cc new file mode 100644 index 00000000..a1e6bffe --- /dev/null +++ b/src/operators/gather_elements.cc @@ -0,0 +1,70 @@ +#include "operators/gather.h" +#include "utils/operator_utils.h" + +namespace infini { +GatherElementsObj::GatherElementsObj(GraphObj *graph, Tensor input, + Tensor indices, Tensor output, int axis) + : GatherBaseObj(OpType::GatherElements, {input, indices}, {output}, axis) { + int rank = input->getRank(); + this->axis = get_real_axis(axis, rank); + IT_ASSERT(checkValid(graph)); +} + +bool checkShape(Tensor input, Tensor indices, int axis) { + auto inputDims = input->getDims(); + auto indicesDims = indices->getDims(); + if (input->getRank() != indices->getRank()) { + return false; + } + for (int i = 0; i < static_cast(input->getRank()); ++i) { + if (i != axis && inputDims[i] != indicesDims[i]) { + return false; + } + } + return true; +} + +optional> +GatherElementsObj::inferShape(const TensorVec &inputs) const { + IT_ASSERT(checkShape(inputs[0], inputs[1], axis)); + auto indicesDims = inputs[1]->getDims(); // output has same shape as indices + return {{indicesDims}}; +} + +vector +GatherElementsObj::inferDataType(const TensorVec &inputs) const { + IT_ASSERT(inputs.size() == 2); + auto indexDtype = inputs[1]->getDType(); + IT_ASSERT(indexDtype == DataType::Int32 || indexDtype == DataType::Int64); + return {inputs[0]->getDType()}; +} + +std::string GatherElementsObj::toString() const { + std::ostringstream os; + os << "GatherElements" + << "[" << getGuid() << "]"; + os << "("; + if (inputs.size() == 2) { + os << vecToString(inputs[0]->getDims()) << ","; + os << vecToString(inputs[1]->getDims()) << ","; + } + os << "axis=" << axis << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector GatherElementsObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), type.underlying()); + for (auto it : inputs[1]->getDims()) + ret.emplace_back(it); + ret.emplace_back(axis); + return ret; +} + +vector GatherElementsObj::getOpAttrVector() const { + return {type.underlying(), axis}; +} + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_gather_elements.cc b/test/kernels/cuda/test_cuda_gather_elements.cc new file mode 100644 index 00000000..6cd2f410 --- /dev/null +++ b/test/kernels/cuda/test_cuda_gather_elements.cc @@ -0,0 +1,43 @@ +#include "core/graph.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "cuda/gather.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { +TEST(GatherElements, intDataLongIndices) { + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputCuda = gCuda->addTensor({3, 3}, DataType::Int32); + auto indexCuda = gCuda->addTensor({2, 3}, DataType::Int64); + auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 0); + gCuda->dataMalloc(); + inputCuda->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9}); + indexCuda->copyin(vector{1, 2, 0, 2, 0, 0}); + + cudaRuntime->run(gCuda); + auto result = op->getOutput()->clone(cpuRuntime); + EXPECT_TRUE(result->equalData({4, 8, 3, 7, 2, 3})); +} + +TEST(GatherElements, floatDataIntIndices) { + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputCuda = gCuda->addTensor({2, 2}, DataType::Float32); + auto indexCuda = gCuda->addTensor({2, 2}, DataType::Int32); + auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 1); + gCuda->dataMalloc(); + inputCuda->copyin(vector{1., 2., 3., 4.}); + indexCuda->copyin(vector{0, 0, 1, 0}); + + cudaRuntime->run(gCuda); + auto result = op->getOutput()->clone(cpuRuntime); + EXPECT_TRUE(result->equalData({1., 1., 4., 3.})); +} +} // namespace infini diff --git a/test/operators/test_gather_elements.cc b/test/operators/test_gather_elements.cc new file mode 100644 index 00000000..a5b6188a --- /dev/null +++ b/test/operators/test_gather_elements.cc @@ -0,0 +1,29 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { + +TEST(Gather, ShapeTypeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({3, 3, 3}, DataType::Int32); + Tensor index = g->addTensor({2, 3, 3}, DataType::Int32); + auto op = g->addOp(i, index, nullptr, 0); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Int32); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 4, 2}, DataType::Float32); + Tensor index = g->addTensor({2, 1, 2}, DataType::Int64); + auto op = g->addOp(i, index, nullptr, 1); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 2})); + } +} +} // namespace infini