diff --git a/include/operators/extend.h b/include/operators/extend.h
new file mode 100644
index 00000000..d3ef64fe
--- /dev/null
+++ b/include/operators/extend.h
@@ -0,0 +1,23 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class ExtendObj : public OperatorObj {
+    int dim, num; // append num extra copies along dimension dim
+
+  public:
+    ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
+              int num = 1);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+    int getDim() const { return dim; }
+    int getNum() const { return num; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cuda/extend.cc b/src/kernels/cuda/extend.cc
new file mode 100644
index 00000000..a5603e02
--- /dev/null
+++ b/src/kernels/cuda/extend.cc
@@ -0,0 +1,27 @@
+#include "operators/extend.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+
+namespace infini {
+void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
+                   int oSize);
+class ExtendCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ExtendObj>(_op);
+        auto inData = op->getInputs(0)->getRawDataPtr<float *>();
+        auto outData = op->getOutputs()[0]->getRawDataPtr<float *>();
+        int blockSize = 1;
+        auto iDim = op->getInputs(0)->getDims();
+        for (size_t i = iDim.size() - 1;
+             i >= (size_t)op->getDim() && i != (size_t)-1; --i)
+            blockSize *= iDim[i];
+        auto blockSizeOuter = (op->getNum() + 1) * blockSize;
+
+        extend_kernel(inData, outData, blockSize, blockSizeOuter,
+                      op->getOutput()->size());
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda,
+                "Extend_CUDA_Float32");
+} // namespace infini
diff --git a/src/kernels/cuda/extend.cu b/src/kernels/cuda/extend.cu
new file mode 100644
index 00000000..05cf95cb
--- /dev/null
+++ b/src/kernels/cuda/extend.cu
@@ -0,0 +1,25 @@
+#include "cuda/cuda_common.h"
+
+__global__ void _extend_kernel(float *in, float *out, int blockSize,
+                               int blockSizeOuter, int oSize) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (index >= oSize)
+        return;
+
+    int stride = blockDim.x * gridDim.x;
+    while (index < oSize) {
+        auto iIdx = index % blockSize + index / blockSizeOuter * blockSize;
+        out[index] = in[iIdx];
+        index += stride;
+    }
+}
+
+namespace infini {
+void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
+                   int oSize) {
+    int blocksize = 32 * 16;
+    int gridsize = (oSize + blocksize - 1) / blocksize;
+    _extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter,
+                                            oSize);
+}
+} // namespace infini
\ No newline at end of file
diff --git a/src/operators/extend.cc b/src/operators/extend.cc
new file mode 100644
index 00000000..55ef9021
--- /dev/null
+++ b/src/operators/extend.cc
@@ -0,0 +1,41 @@
+#include "operators/extend.h"
+
+namespace infini {
+
+ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
+                     int num)
+    : OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const {
+    auto ret = inputs[0]->getDims();
+    IT_ASSERT((size_t)dim < ret.size());
+    ret[dim] = ret[dim] * (num + 1);
+    return {{ret}};
+}
+std::string ExtendObj::toString() const {
+    std::ostringstream os;
+    os << "Extend[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "dim=" << dim << ",";
+    os << "num=" << num << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> ExtendObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace_back(dim);
+    ret.emplace_back(num);
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    return ret;
+}
+
+vector<int> ExtendObj::getOpAttrVector() const {
+    return {enum_to_underlying(type), dim, num};
+}
+
+} // namespace infini
\ No newline at end of file
diff --git a/test/kernels/cuda/test_cuda_extend.cc b/test/kernels/cuda/test_cuda_extend.cc
new file mode 100644
index 00000000..75167649
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_extend.cc
@@ -0,0 +1,43 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/extend.h"
+
+#include "test.h"
+
+namespace infini {
+
+TEST(CUDA_Extend, run) {
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor icpu =
+        make_ref<TensorObj>(Shape{2, 3, 2, 2}, DataType::Float32, cpuRuntime);
+    icpu->dataMalloc();
+    icpu->setData(IncrementalGenerator());
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto i = g->cloneTensor(icpu);
+    auto op = g->addOp<ExtendObj>(i, nullptr, 1, 1);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto o = op->getOutput();
+    auto ocpu = o->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_TRUE(ocpu->equalData(vector<float>{
+        0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 0,  1,  2,  3,
+        4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+        20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}));
+}
+} // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_extend.cc b/test/operators/test_extend.cc
new file mode 100644
index 00000000..c9c079fd
--- /dev/null
+++ b/test/operators/test_extend.cc
@@ -0,0 +1,20 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/extend.h"
+
+#include "test.h"
+
+namespace infini {
+
+TEST(Extend, ShapeInference) {
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
+        auto op = g->addOp<ExtendObj>(i, nullptr, 2, 1);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
+    }
+}
+
+} // namespace infini
\ No newline at end of file