diff --git a/include/cuda/cuda_expand.h b/include/cuda/cuda_expand.h index b53c4ce4..8d4701fd 100644 --- a/include/cuda/cuda_expand.h +++ b/include/cuda/cuda_expand.h @@ -3,7 +3,7 @@ #include "operators/unary.h" #include "utils/small_array.h" namespace infini { -void expand_kernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape); +void expandKernel(float *input, float *output, int nDims, int outputsize, + SmallArray inputShape, SmallArray outputShape); }; // namespace infini diff --git a/include/cuda/cuda_where.h b/include/cuda/cuda_where.h index 14d9bc73..15ad29ec 100644 --- a/include/cuda/cuda_where.h +++ b/include/cuda/cuda_where.h @@ -3,11 +3,9 @@ #include "utils/small_array.h" namespace infini { -void where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape); +void whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape); }; // namespace infini diff --git a/include/utils/broadcast_shape.h b/include/utils/broadcast_shape.h new file mode 100644 index 00000000..1f45ddcc --- /dev/null +++ b/include/utils/broadcast_shape.h @@ -0,0 +1,14 @@ +#pragma once + +namespace infini { +inline void broadcastShape(const Shape &originShape, SmallArray &modifyShape, + int nDims, int size) { + for (int i = nDims - 1; i >= 0; --i) { + modifyShape.data[i] = 1; + } + for (int i = size - 1; i >= 0; --i) { + modifyShape.data[i + nDims - size] = originShape[i]; + } +} + +} // namespace infini diff --git a/src/kernels/cuda/expand.cc b/src/kernels/cuda/expand.cc index b8154d49..acbf5cd2 100644 --- a/src/kernels/cuda/expand.cc +++ b/src/kernels/cuda/expand.cc @@ -25,8 +25,8 @@ class ExpandCuda : public 
CudaKernelWithoutConfig { inputShape.data[i] = in_Shape[i]; outputsize *= out_Shape[i]; } - expand_kernel((float *)inputData, (float *)outputData, nDims, - outputsize, inputShape, outputShape); + expandKernel((float *)inputData, (float *)outputData, nDims, outputsize, + inputShape, outputShape); } }; diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu index e1649b81..09405d09 100644 --- a/src/kernels/cuda/expand.cu +++ b/src/kernels/cuda/expand.cu @@ -6,9 +6,9 @@ constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _expand_kernel(float *input, float *output, int nDims, - int outputsize, infini::SmallArray inputShape, - infini::SmallArray outputShape) { +__global__ void _expandKernel(float *input, float *output, int nDims, + int outputsize, infini::SmallArray inputShape, + infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s @@ -38,12 +38,12 @@ __global__ void _expand_kernel(float *input, float *output, int nDims, } namespace infini { -void expand_kernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape) { +void expandKernel(float *input, float *output, int nDims, int outputsize, + SmallArray inputShape, SmallArray outputShape) { int blocksize = block_work_size(); int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); - _expand_kernel<<>>(input, output, nDims, outputsize, - inputShape, outputShape); + _expandKernel<<>>(input, output, nDims, outputsize, + inputShape, outputShape); } } // namespace infini diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc index 4769fea0..9898ab7d 100644 --- a/src/kernels/cuda/where.cc +++ b/src/kernels/cuda/where.cc @@ -2,6 +2,7 @@ #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" #include 
"cuda/cuda_where.h" +#include "utils/broadcast_shape.h" namespace infini { @@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig { const RuntimeObj *_context) const override { auto op = as(_op); - void *const inputxData = (op->getInputs(0)->getRawDataPtr()); - void *const inputyData = (op->getInputs(1)->getRawDataPtr()); + void *const inputXData = (op->getInputs(0)->getRawDataPtr()); + void *const inputYData = (op->getInputs(1)->getRawDataPtr()); void *const conditionData = (op->getInputs(2)->getRawDataPtr()); void *const outputData = (op->getOutput()->getRawDataPtr()); - const auto &inputx_Shape = op->getInputs(0)->getDims(); - const auto &inputy_Shape = op->getInputs(1)->getDims(); - const auto &condition_Shape = op->getInputs(2)->getDims(); - const auto &output_Shape = op->getOutput()->getDims(); + const auto &opInputXShape = op->getInputs(0)->getDims(); + const auto &opInputYShape = op->getInputs(1)->getDims(); + const auto &opConditionShape = op->getInputs(2)->getDims(); + const auto &opOutputShape = op->getOutput()->getDims(); - int nDims = op->getInputs(0)->getDims().size(); + const int xSize = op->getInputs(0)->getRank(); + const int ySize = op->getInputs(1)->getRank(); + const int cSize = op->getInputs(2)->getRank(); + int nDims = op->getOutput()->getDims().size(); IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); - SmallArray inputxShape, inputyShape, conditionShape, outputShape; - for (int i = 0; i < nDims; ++i) { - inputxShape.data[i] = inputx_Shape[i]; - inputyShape.data[i] = inputy_Shape[i]; - conditionShape.data[i] = condition_Shape[i]; - outputShape.data[i] = output_Shape[i]; + SmallArray inputXShape, inputYShape, conditionShape, outputShape; + for (int i = nDims - 1; i >= 0; --i) { + outputShape.data[i] = opOutputShape[i]; } - where_kernel((float *)inputxData, (float *)inputyData, - (float *)conditionData, (float *)outputData, nDims, - inputxShape, inputyShape, conditionShape, outputShape); + + broadcastShape(opInputXShape, inputXShape, nDims, 
xSize); + broadcastShape(opInputYShape, inputYShape, nDims, ySize); + broadcastShape(opConditionShape, conditionShape, nDims, cSize); + + whereKernel((float *)inputXData, (float *)inputYData, + (uint8_t *)conditionData, (float *)outputData, nDims, + inputXShape, inputYShape, conditionShape, outputShape); } }; diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu index 7d34098c..ce6579f8 100644 --- a/src/kernels/cuda/where.cu +++ b/src/kernels/cuda/where.cu @@ -1,20 +1,20 @@ #include "cuda/cuda_common.h" #include "utils/small_array.h" -__global__ void _where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - int outputsize, infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape) { +__global__ void _whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + int outputsize, infini::SmallArray inputXShape, + infini::SmallArray inputYShape, + infini::SmallArray conditionShape, + infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; if (outputIdx < outputsize) { - int inputxIdx = 0; - int temp_inputx = 1; + int inputXIdx = 0; + int temp_inputX = 1; - int inputyIdx = 0; - int temp_inputy = 1; + int inputYIdx = 0; + int temp_inputY = 1; int conditionIdx = 0; int temp_condition = 1; @@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy, } else { tmp = v % outputShape.data[i]; // store s,k,j in order } - if (inputxShape.data[i] == 1) { - inputxIdx += 0; + if (inputXShape.data[i] == 1) { + inputXIdx += 0; } else { - inputxIdx += + inputXIdx += tmp * - temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s + temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s } - temp_inputx *= inputxShape.data[i]; + temp_inputX *= inputXShape.data[i]; //---------------------------- - if (inputyShape.data[i] == 1) { - 
inputyIdx += 0; + if (inputYShape.data[i] == 1) { + inputYIdx += 0; } else { - inputyIdx += + inputYIdx += tmp * - temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s + temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s } - temp_inputy *= inputyShape.data[i]; + temp_inputY *= inputYShape.data[i]; //-------------------------- if (conditionShape.data[i] == 1) { conditionIdx += 0; @@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy, v = v / outputShape.data[i]; } output[outputIdx] = - condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx]; + condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx]; } } namespace infini { -void where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape) { +void whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape) { int outputsize = 1; for (int i = 0; i < nDims; i++) { @@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy, } int blocksize = 32 * 16; int gridsize = (outputsize + blocksize - 1) / blocksize; - _where_kernel<<>>( - inputx, inputy, condition, output, nDims, outputsize, inputxShape, - inputyShape, conditionShape, outputShape); + _whereKernel<<>>( + inputX, inputY, condition, output, nDims, outputsize, inputXShape, + inputYShape, conditionShape, outputShape); } } // namespace infini diff --git a/test/kernels/cuda/test_cuda_where.cc b/test/kernels/cuda/test_cuda_where.cc index 61515445..74f114d4 100644 --- a/test/kernels/cuda/test_cuda_where.cc +++ b/test/kernels/cuda/test_cuda_where.cc @@ -10,11 +10,12 @@ namespace infini { void test_where(const Shape &inputxshape, const vector &inputxdata, const Shape 
&inputyshape, const vector &inputydata, - const Shape &conditionshape, const vector &conditiondata, + const Shape &conditionshape, + const vector &conditiondata, const vector &ExpectData) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); - auto condition = gCpu->addTensor(conditionshape, DataType::Int32); + auto condition = gCpu->addTensor(conditionshape, DataType::UInt8); auto inputx = gCpu->addTensor(inputxshape, DataType::Float32); auto inputy = gCpu->addTensor(inputyshape, DataType::Float32); @@ -47,16 +48,37 @@ TEST(CUDA_Where, run) { test_where( Shape{2, 2, 3, 1}, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, Shape{2, 2, 3, 1}, vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, + Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, vector{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.}); test_where(Shape{2, 1, 1, 3}, // inputx vector{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy vector{1, 1}, Shape{2, 1, 3, 1}, // condition - vector{0, 1, 1, 0, 0, 0}, + vector{0, 1, 1, 0, 0, 0}, vector{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + test_where( + Shape{ + 3, + }, + vector{0, 1, 2}, // inputX + Shape{2, 3, 1}, vector{0, 1, 2, 3, 4, 5}, // inputY + Shape{2, 1, 3, 1}, vector{0, 1, 1, 0, 0, 0}, // condition + vector{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3., + 0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1., + 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.}); + test_where( + Shape{ + 3, + }, + vector{0, 1, 2}, // inputX + Shape{2, 3, 1}, vector{0, 1, 2, 3, 4, 5}, // inputY + Shape{2, 1, 3, 1}, + vector{false, true, true, false, false, false}, // condition + vector{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3., + 0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1., + 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.}); } // python output