diff --git a/include/core/tensor.h b/include/core/tensor.h index 9b9b4237..cab503b8 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -20,7 +20,7 @@ class TensorObj : public TensorBaseObj { size_t getBytes() const; Shape getDims() const { return shape; } - + vector getStride() const; size_t getOffset(const Shape &ds) const; using TensorBaseObj::getData; VType getData(const Shape &pos) const; diff --git a/include/cuda/cuda_pad_slice.h b/include/cuda/cuda_pad_slice.h new file mode 100644 index 00000000..9a452691 --- /dev/null +++ b/include/cuda/cuda_pad_slice.h @@ -0,0 +1,19 @@ +#pragma once + +const int MAX_DIM = 4; + +// Pad operator acts like padding small(part) tensor into a big(whole) tensor. +// Slice operator acts like spling a big(whole) tensor into a small(part) +// tensor. +typedef struct { + int begNum[MAX_DIM]; // pad or slice number at beginning + int wholeNDim[MAX_DIM]; // dim size after padding or before slicing + int partNDim[MAX_DIM]; // dim size before padding or after slicing + int partStride[MAX_DIM]; // stride before padding or after slicing +} TransMetaData; + +namespace infini { +void pad_slice_kernel(float *partData, float *wholeData, + const TransMetaData &metadata, int nDims, int num, + bool isPad); +} // namespace infini \ No newline at end of file diff --git a/include/operators/pad.h b/include/operators/pad.h new file mode 100644 index 00000000..2e80666a --- /dev/null +++ b/include/operators/pad.h @@ -0,0 +1,24 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class PadObj : public OperatorObj { + // the number of start and end pad values for all dims. + vector pads; + + public: + // pad for appointed axises,if axis is empty,then pad for all axises. + PadObj(GraphObj *graph, Tensor input, Tensor output, + const vector &pads, const optional> &axis); + + optional> inferShape(const TensorVec &inputs) const override; + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + Shape PadObj::getPads() const { return pads; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/include/operators/slice.h b/include/operators/slice.h new file mode 100644 index 00000000..5c78752a --- /dev/null +++ b/include/operators/slice.h @@ -0,0 +1,24 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class SliceObj : public OperatorObj { + vector starts, ends; // the start no. and end no. for all dims. + + public: + SliceObj(GraphObj *graph, Tensor input, Tensor output, + const vector &starts, const vector &ends, + const optional> &axis, + const optional> &steps); + + optional> inferShape(const TensorVec &inputs) const override; + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + Shape getStart() const { return starts; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini \ No newline at end of file diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 7fe67207..0eaad451 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -30,6 +30,17 @@ size_t TensorObj::getOffset(const Shape &pos) const { return idx; } +vector TensorObj::getStride() const { + vector ret; + size_t stride = 1; + for (int i = shape.size() - 1; i >= 1; i--) { + ret.emplace(ret.begin(), stride); + stride *= shape.at(i); + } + ret.emplace(ret.begin(), stride); + return ret; +} + size_t TensorObj::size() const { size_t ret = 1; for (const auto &d : shape) diff --git a/src/kernels/cuda/pad_slice.cc b/src/kernels/cuda/pad_slice.cc new file mode 100644 index 00000000..04982a41 --- /dev/null +++ b/src/kernels/cuda/pad_slice.cc @@ -0,0 +1,45 @@ +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_pad_slice.h" +#include "operators/pad.h" +#include "operators/slice.h" +namespace infini { +class PadSliceCudaCompute { + public: + void do_compute(Tensor partTensor, Tensor wholeTensor, const Shape &begNos, + bool isPad) const { + int nDims = partTensor->getDims().size(); + IT_ASSERT(MAX_DIM >= nDims); + TransMetaData metadata; + for (int i = 0; i < nDims; i++) { + metadata.begNum[i] = begNos[i]; + metadata.wholeNDim[i] = wholeTensor->getDims()[i]; + metadata.partNDim[i] = partTensor->getDims()[i]; + metadata.partStride[i] = partTensor->getStride()[i]; + } + pad_slice_kernel(partTensor->getRawDataPtr(), + wholeTensor->getRawDataPtr(), metadata, nDims, + wholeTensor->size(), isPad); + } +}; + +class PadCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &op, + const RuntimeObj *_context) const override { + do_compute(op->getInputs(0), op->getOutput(), as(op)->getPads(), + true); + } +}; + +class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig { + void compute(const Operator &op, + const RuntimeObj *_context) const override { + do_compute(op->getOutput(), op->getInputs(0), + as(op)->getStart(), false); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda, + "Slice__CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda, + "Pad__CUDA_Float32"); +} // namespace infini diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu new file mode 100644 index 00000000..828aba3e --- /dev/null +++ b/src/kernels/cuda/pad_slice.cu @@ -0,0 +1,52 @@ +#include "cuda/cuda_common.h" +#include "cuda/cuda_pad_slice.h" + +__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset, + TransMetaData metaData, + int nDims) { + int offset = 0; + for (int i = nDims - 1; i >= 0; --i) { + auto wholePos = wholeOffset % metaData.wholeNDim[i]; + auto pos = wholePos - metaData.begNum[i]; + // if pos belongs to pad range, then return -1 + if (pos < 0 || pos >= metaData.partNDim[i]) + return -1; + wholeOffset = wholeOffset / metaData.wholeNDim[i]; + + offset += pos * metaData.partStride[i]; + } + + return offset; +} + +__global__ void _pad_slice_kernel(float *part, float *whole, + TransMetaData metaData, int nDims, int num, + bool isPad) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= num) + return; + + int stride = blockDim.x * gridDim.x; + while (tid < num) { + int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims); + if (isPad) + if (offset < 0) + whole[tid] = 0; + else + whole[tid] = part[offset]; + else + part[offset] = whole[tid]; + tid += stride; + } +} + +namespace infini { +void pad_slice_kernel(float *partData, float *wholeData, + const TransMetaData &metadata, int nDims, int num, + bool isPad) { + int blockSize = 32 * 16; + int gridSize = (num + blockSize - 1) / blockSize; + _pad_slice_kernel<<>>(partData, wholeData, metadata, + nDims, num, isPad); +} +} // namespace infini diff --git a/src/operators/pad.cc b/src/operators/pad.cc new file mode 100644 index 00000000..f3e219d6 --- /dev/null +++ b/src/operators/pad.cc @@ -0,0 +1,63 @@ +#include "operators/pad.h" + +namespace infini { +PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output, + const vector &_pads, + const optional> &axis) + : OperatorObj(OpType::Pad, {input}, {output}) { + if (axis == std::nullopt) + pads = _pads; + else { + int nAxis = (*axis).size(); + IT_ASSERT((int)_pads.size() == nAxis * 2); + int nDims = input->getDims().size(); + vector tmp(nDims * 2, 0); + + for (int i = 0; i < nAxis; ++i) { + tmp[(*axis)[i]] = _pads[i]; + tmp[(*axis)[i] + nDims] = _pads[i + nAxis]; + } + pads = tmp; + } + IT_ASSERT(checkValid(graph)); +} + +optional> PadObj::inferShape(const TensorVec &inputs) const { + auto dims = inputs[0]->getDims(); + int nDims = dims.size(); + if (nDims * 2 != (int)pads.size()) + return {}; + for (int i = 0; i < nDims; ++i) { + if (pads[i] < 0 || pads[i + nDims] < 0) + return {}; + dims[i] += pads[i] + pads[i + nDims]; + } + + return {{dims}}; +} +std::string PadObj::toString() const { + std::ostringstream os; + os << "Pad" + << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "pads=" << vecToString(pads) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector PadObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.insert(ret.end(), pads.begin(), pads.end()); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +vector PadObj::getOpAttrVector() const { + vector ret = pads; + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +} // namespace infini diff --git a/src/operators/slice.cc b/src/operators/slice.cc new file mode 100644 index 00000000..5987531f --- /dev/null +++ b/src/operators/slice.cc @@ -0,0 +1,80 @@ +#include "operators/slice.h" + +namespace infini { +SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output, + const vector &starts, const vector &ends, + const optional> &axis, + const optional> &steps) + : OperatorObj(OpType::Slice, {input}, {output}) { + if (steps != std::nullopt) + IT_TODO_HALT(); + IT_ASSERT(starts.size() == ends.size()); + + if (axis == std::nullopt) { + this->starts = starts; + this->ends = ends; + } else { + int nAxis = (*axis).size(); + IT_ASSERT((int)starts.size() == nAxis); + + int nDims = input->getDims().size(); + vector tmpS(nDims, 0), tmpE; + for (int i = 0; i < nDims; ++i) { + tmpE.emplace_back(input->getDims()[i] - 1); + } + + for (int i = 0; i < nAxis; ++i) { + if ((*axis)[i] < 0) + IT_TODO_HALT(); + tmpS[(*axis)[i]] = starts[i]; + tmpE[(*axis)[i]] = ends[i]; + } + this->starts = tmpS; + this->ends = tmpE; + } + IT_ASSERT(checkValid(graph)); +} + +optional> SliceObj::inferShape(const TensorVec &inputs) const { + auto dims = inputs[0]->getDims(); + int nDims = dims.size(); + if (nDims != (int)starts.size()) + return {}; + for (int i = 0; i < nDims; ++i) { + if (starts[i] < 0 || ends[i] >= dims[i] || starts[i] > ends[i]) + return {}; + dims[i] = ends[i] - starts[i] + 1; + } + + return {{dims}}; +} + +std::string SliceObj::toString() const { + std::ostringstream os; + os << "Slice" + << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "starts=" << vecToString(starts) << ","; + os << "ends=" << vecToString(ends) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector SliceObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.insert(ret.end(), starts.begin(), starts.end()); + ret.insert(ret.end(), ends.begin(), ends.end()); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +vector SliceObj::getOpAttrVector() const { + vector ret = starts; + ret.insert(ret.end(), ends.begin(), ends.end()); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_pad.cc b/test/kernels/cuda/test_cuda_pad.cc new file mode 100644 index 00000000..dfe3a188 --- /dev/null +++ b/test/kernels/cuda/test_cuda_pad.cc @@ -0,0 +1,41 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/pad.h" +#include "test.h" + +namespace infini { +TEST(Pad, Cuda) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = + make_ref(Shape{1, 2, 3, 2}, DataType::Float32, cpuRuntime); + icpu->dataMalloc(); + icpu->setData(IncrementalGenerator()); + + // Build CUDA graph; + Graph g = make_ref(cudaRuntime); + auto i = g->cloneTensor(icpu); + auto op = g->addOp(i, nullptr, vector{1, 0, 1, 1}, + vector{0, 3}); + + // allocate CUDA memory + g->dataMalloc(); + + // Execute on CUDA + cudaRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto cpuo = o->clone(cpuRuntime); + // cudaPrintTensor(o); + // check results on CPU + EXPECT_TRUE(cpuo->equalData( + vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); +} +} // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_cuda_slice.cc b/test/kernels/cuda/test_cuda_slice.cc new file mode 100644 index 00000000..3cd7da6c --- /dev/null +++ b/test/kernels/cuda/test_cuda_slice.cc @@ -0,0 +1,39 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/slice.h" +#include "test.h" + +namespace infini { +TEST(CUDA_Slice, run) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = + make_ref(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime); + icpu->dataMalloc(); + icpu->setData(IncrementalGenerator()); + + // Build CUDA graph; + Graph g = make_ref(cudaRuntime); + auto i = g->cloneTensor(icpu); + auto op = + g->addOp(i, nullptr, vector{1, 1}, vector{1, 4}, + vector{0, 3}, std::nullopt); + + // allocate CUDA memory + g->dataMalloc(); + + // Execute on CUDA + cudaRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto cpuo = o->clone(cpuRuntime); + // cudaPrintTensor(o); + // check results on CPU + EXPECT_TRUE(cpuo->equalData(vector{11, 12, 13, 14, 16, 17, 18, 19})); +} +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_pad.cc b/test/operators/test_pad.cc new file mode 100644 index 00000000..23c11afd --- /dev/null +++ b/test/operators/test_pad.cc @@ -0,0 +1,25 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/pad.h" +#include "test.h" + +namespace infini { +TEST(Pad, ShapeInference) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); + auto op = g->addOp( + i, nullptr, vector{2, 10, 1, 5, 0, 10, 1, 5}, std::nullopt); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 84, 164, 172})); + } + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); + auto op = g->addOp(i, nullptr, vector{2, 10, 1, 5}, + vector{0, 3}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{4, 64, 162, 177})); + } +} + +} // namespace infini \ No newline at end of file diff --git a/test/operators/test_slice.cc b/test/operators/test_slice.cc new file mode 100644 index 00000000..0a9430fc --- /dev/null +++ b/test/operators/test_slice.cc @@ -0,0 +1,27 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/slice.h" +#include "test.h" + +namespace infini { +TEST(Slice, ShapeInference) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32); + auto op = g->addOp(i, nullptr, vector{2, 10, 1, 5}, + vector{3, 10, 100, 100}, std::nullopt, + std::nullopt); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 100, 96})); + } + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32); + auto op = g->addOp(i, nullptr, vector{2, 5}, + vector{3, 100}, vector{1, 3}, + std::nullopt); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{10, 2, 162, 96})); + } +} + +} // namespace infini \ No newline at end of file