"modified where" (#131)

* "modified where"

* "adapt int or bool condition datatype"

* "add broadcast_shape.h,error"

* add broadcast.h

* "modified broadcast_shape.h and where.cc,.cu"
This commit is contained in:
xgqdut2016 2023-09-14 10:45:57 +08:00 committed by GitHub
parent f60767a770
commit dda668fd16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 67 deletions

View File

@ -3,7 +3,7 @@
#include "operators/unary.h" #include "operators/unary.h"
#include "utils/small_array.h" #include "utils/small_array.h"
namespace infini { namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize, void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape); SmallArray inputShape, SmallArray outputShape);
}; // namespace infini }; // namespace infini

View File

@ -3,11 +3,9 @@
#include "utils/small_array.h" #include "utils/small_array.h"
namespace infini { namespace infini {
void where_kernel(const float *inputx, const float *inputy, void whereKernel(const float *inputX, const float *inputY,
const float *condition, float *output, int nDims, const uint8_t *condition, float *output, int nDims,
infini::SmallArray inputxShape, SmallArray inputXShape, SmallArray inputYShape,
infini::SmallArray inputyShape, SmallArray conditionShape, SmallArray outputShape);
infini::SmallArray conditionShape,
infini::SmallArray outputShape);
}; // namespace infini }; // namespace infini

View File

@ -0,0 +1,14 @@
#pragma once
namespace infini {
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
int nDims, int size) {
for (int i = nDims - 1; i >= 0; --i) {
modifyShape.data[i] = 1;
}
for (int i = size - 1; i >= 0; --i) {
modifyShape.data[i + nDims - size] = originShape[i];
}
}
} // namespace infini

View File

@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig {
inputShape.data[i] = in_Shape[i]; inputShape.data[i] = in_Shape[i];
outputsize *= out_Shape[i]; outputsize *= out_Shape[i];
} }
expand_kernel((float *)inputData, (float *)outputData, nDims, expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
outputsize, inputShape, outputShape); inputShape, outputShape);
} }
}; };

View File

@ -6,7 +6,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; } constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); } constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _expand_kernel(float *input, float *output, int nDims, __global__ void _expandKernel(float *input, float *output, int nDims,
int outputsize, infini::SmallArray inputShape, int outputsize, infini::SmallArray inputShape,
infini::SmallArray outputShape) { infini::SmallArray outputShape) {
@ -38,11 +38,11 @@ __global__ void _expand_kernel(float *input, float *output, int nDims,
} }
namespace infini { namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize, void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape) { SmallArray inputShape, SmallArray outputShape) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
_expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize, _expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
inputShape, outputShape); inputShape, outputShape);
} }

View File

@ -2,6 +2,7 @@
#include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h" #include "cuda/cuda_runtime.h"
#include "cuda/cuda_where.h" #include "cuda/cuda_where.h"
#include "utils/broadcast_shape.h"
namespace infini { namespace infini {
@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig {
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op); auto op = as<WhereObj>(_op);
void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputXData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>()); void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>()); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &inputx_Shape = op->getInputs(0)->getDims(); const auto &opInputXShape = op->getInputs(0)->getDims();
const auto &inputy_Shape = op->getInputs(1)->getDims(); const auto &opInputYShape = op->getInputs(1)->getDims();
const auto &condition_Shape = op->getInputs(2)->getDims(); const auto &opConditionShape = op->getInputs(2)->getDims();
const auto &output_Shape = op->getOutput()->getDims(); const auto &opOutputShape = op->getOutput()->getDims();
int nDims = op->getInputs(0)->getDims().size(); const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
SmallArray inputxShape, inputyShape, conditionShape, outputShape; SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = 0; i < nDims; ++i) { for (int i = nDims - 1; i >= 0; --i) {
inputxShape.data[i] = inputx_Shape[i]; outputShape.data[i] = opOutputShape[i];
inputyShape.data[i] = inputy_Shape[i];
conditionShape.data[i] = condition_Shape[i];
outputShape.data[i] = output_Shape[i];
} }
where_kernel((float *)inputxData, (float *)inputyData,
(float *)conditionData, (float *)outputData, nDims, broadcastShape(opInputXShape, inputXShape, nDims, xSize);
inputxShape, inputyShape, conditionShape, outputShape); broadcastShape(opInputYShape, inputYShape, nDims, ySize);
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
whereKernel((float *)inputXData, (float *)inputYData,
(uint8_t *)conditionData, (float *)outputData, nDims,
inputXShape, inputYShape, conditionShape, outputShape);
} }
}; };

View File

@ -1,20 +1,20 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "utils/small_array.h" #include "utils/small_array.h"
__global__ void _where_kernel(const float *inputx, const float *inputy, __global__ void _whereKernel(const float *inputX, const float *inputY,
const float *condition, float *output, int nDims, const uint8_t *condition, float *output, int nDims,
int outputsize, infini::SmallArray inputxShape, int outputsize, infini::SmallArray inputXShape,
infini::SmallArray inputyShape, infini::SmallArray inputYShape,
infini::SmallArray conditionShape, infini::SmallArray conditionShape,
infini::SmallArray outputShape) { infini::SmallArray outputShape) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < outputsize) { if (outputIdx < outputsize) {
int inputxIdx = 0; int inputXIdx = 0;
int temp_inputx = 1; int temp_inputX = 1;
int inputyIdx = 0; int inputYIdx = 0;
int temp_inputy = 1; int temp_inputY = 1;
int conditionIdx = 0; int conditionIdx = 0;
int temp_condition = 1; int temp_condition = 1;
@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
} else { } else {
tmp = v % outputShape.data[i]; // store s,k,j in order tmp = v % outputShape.data[i]; // store s,k,j in order
} }
if (inputxShape.data[i] == 1) { if (inputXShape.data[i] == 1) {
inputxIdx += 0; inputXIdx += 0;
} else { } else {
inputxIdx += inputXIdx +=
tmp * tmp *
temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
} }
temp_inputx *= inputxShape.data[i]; temp_inputX *= inputXShape.data[i];
//---------------------------- //----------------------------
if (inputyShape.data[i] == 1) { if (inputYShape.data[i] == 1) {
inputyIdx += 0; inputYIdx += 0;
} else { } else {
inputyIdx += inputYIdx +=
tmp * tmp *
temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
} }
temp_inputy *= inputyShape.data[i]; temp_inputY *= inputYShape.data[i];
//-------------------------- //--------------------------
if (conditionShape.data[i] == 1) { if (conditionShape.data[i] == 1) {
conditionIdx += 0; conditionIdx += 0;
@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
v = v / outputShape.data[i]; v = v / outputShape.data[i];
} }
output[outputIdx] = output[outputIdx] =
condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx]; condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
} }
} }
namespace infini { namespace infini {
void where_kernel(const float *inputx, const float *inputy, void whereKernel(const float *inputX, const float *inputY,
const float *condition, float *output, int nDims, const uint8_t *condition, float *output, int nDims,
infini::SmallArray inputxShape, SmallArray inputXShape, SmallArray inputYShape,
infini::SmallArray inputyShape, SmallArray conditionShape, SmallArray outputShape) {
infini::SmallArray conditionShape,
infini::SmallArray outputShape) {
int outputsize = 1; int outputsize = 1;
for (int i = 0; i < nDims; i++) { for (int i = 0; i < nDims; i++) {
@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy,
} }
int blocksize = 32 * 16; int blocksize = 32 * 16;
int gridsize = (outputsize + blocksize - 1) / blocksize; int gridsize = (outputsize + blocksize - 1) / blocksize;
_where_kernel<<<gridsize, blocksize>>>( _whereKernel<<<gridsize, blocksize>>>(
inputx, inputy, condition, output, nDims, outputsize, inputxShape, inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputyShape, conditionShape, outputShape); inputYShape, conditionShape, outputShape);
} }
} // namespace infini } // namespace infini

View File

@ -10,11 +10,12 @@ namespace infini {
void test_where(const Shape &inputxshape, const vector<float> &inputxdata, void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
const Shape &inputyshape, const vector<float> &inputydata, const Shape &inputyshape, const vector<float> &inputydata,
const Shape &conditionshape, const vector<int> &conditiondata, const Shape &conditionshape,
const vector<uint8_t> &conditiondata,
const vector<float> &ExpectData) { const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance(); Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime); Graph gCpu = make_ref<GraphObj>(runtime);
auto condition = gCpu->addTensor(conditionshape, DataType::Int32); auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32); auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32); auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
test_where( test_where(
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.}); vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
test_where(Shape{2, 1, 1, 3}, // inputx test_where(Shape{2, 1, 1, 3}, // inputx
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
vector<int>{0, 1, 1, 0, 0, 0}, vector<uint8_t>{0, 1, 1, 0, 0, 0},
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
test_where(
Shape{
3,
},
vector<float>{0, 1, 2}, // inputX
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
test_where(
Shape{
3,
},
vector<float>{0, 1, 2}, // inputX
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
Shape{2, 1, 3, 1},
vector<uint8_t>{false, true, true, false, false, false}, // condition
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
} // python output } // python output