"modified where" (#131)

* "modified where"

* "adapt int or bool condition datatype"

* "add broadcast_shape.h,error"

* add broadcast.h

* "modified broadcast_shape.h and where.cc,.cu"
This commit is contained in:
xgqdut2016 2023-09-14 10:45:57 +08:00 committed by GitHub
parent f60767a770
commit dda668fd16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 67 deletions

View File

@ -3,7 +3,7 @@
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize,
void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape);
}; // namespace infini

View File

@ -3,11 +3,9 @@
#include "utils/small_array.h"
namespace infini {
void where_kernel(const float *inputx, const float *inputy,
const float *condition, float *output, int nDims,
infini::SmallArray inputxShape,
infini::SmallArray inputyShape,
infini::SmallArray conditionShape,
infini::SmallArray outputShape);
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape);
}; // namespace infini

View File

@ -0,0 +1,14 @@
#pragma once
namespace infini {
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
int nDims, int size) {
for (int i = nDims - 1; i >= 0; --i) {
modifyShape.data[i] = 1;
}
for (int i = size - 1; i >= 0; --i) {
modifyShape.data[i + nDims - size] = originShape[i];
}
}
} // namespace infini

View File

@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig {
inputShape.data[i] = in_Shape[i];
outputsize *= out_Shape[i];
}
expand_kernel((float *)inputData, (float *)outputData, nDims,
outputsize, inputShape, outputShape);
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
inputShape, outputShape);
}
};

View File

@ -6,7 +6,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _expand_kernel(float *input, float *output, int nDims,
__global__ void _expandKernel(float *input, float *output, int nDims,
int outputsize, infini::SmallArray inputShape,
infini::SmallArray outputShape) {
@ -38,11 +38,11 @@ __global__ void _expand_kernel(float *input, float *output, int nDims,
}
namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize,
void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
_expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
inputShape, outputShape);
}

View File

@ -2,6 +2,7 @@
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_where.h"
#include "utils/broadcast_shape.h"
namespace infini {
@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig {
const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op);
void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const inputXData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &inputx_Shape = op->getInputs(0)->getDims();
const auto &inputy_Shape = op->getInputs(1)->getDims();
const auto &condition_Shape = op->getInputs(2)->getDims();
const auto &output_Shape = op->getOutput()->getDims();
const auto &opInputXShape = op->getInputs(0)->getDims();
const auto &opInputYShape = op->getInputs(1)->getDims();
const auto &opConditionShape = op->getInputs(2)->getDims();
const auto &opOutputShape = op->getOutput()->getDims();
int nDims = op->getInputs(0)->getDims().size();
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
SmallArray inputxShape, inputyShape, conditionShape, outputShape;
for (int i = 0; i < nDims; ++i) {
inputxShape.data[i] = inputx_Shape[i];
inputyShape.data[i] = inputy_Shape[i];
conditionShape.data[i] = condition_Shape[i];
outputShape.data[i] = output_Shape[i];
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = opOutputShape[i];
}
where_kernel((float *)inputxData, (float *)inputyData,
(float *)conditionData, (float *)outputData, nDims,
inputxShape, inputyShape, conditionShape, outputShape);
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
whereKernel((float *)inputXData, (float *)inputYData,
(uint8_t *)conditionData, (float *)outputData, nDims,
inputXShape, inputYShape, conditionShape, outputShape);
}
};

View File

@ -1,20 +1,20 @@
#include "cuda/cuda_common.h"
#include "utils/small_array.h"
__global__ void _where_kernel(const float *inputx, const float *inputy,
const float *condition, float *output, int nDims,
int outputsize, infini::SmallArray inputxShape,
infini::SmallArray inputyShape,
__global__ void _whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, infini::SmallArray inputXShape,
infini::SmallArray inputYShape,
infini::SmallArray conditionShape,
infini::SmallArray outputShape) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < outputsize) {
int inputxIdx = 0;
int temp_inputx = 1;
int inputXIdx = 0;
int temp_inputX = 1;
int inputyIdx = 0;
int temp_inputy = 1;
int inputYIdx = 0;
int temp_inputY = 1;
int conditionIdx = 0;
int temp_condition = 1;
@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
} else {
tmp = v % outputShape.data[i]; // store s,k,j in order
}
if (inputxShape.data[i] == 1) {
inputxIdx += 0;
if (inputXShape.data[i] == 1) {
inputXIdx += 0;
} else {
inputxIdx +=
inputXIdx +=
tmp *
temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
}
temp_inputx *= inputxShape.data[i];
temp_inputX *= inputXShape.data[i];
//----------------------------
if (inputyShape.data[i] == 1) {
inputyIdx += 0;
if (inputYShape.data[i] == 1) {
inputYIdx += 0;
} else {
inputyIdx +=
inputYIdx +=
tmp *
temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
}
temp_inputy *= inputyShape.data[i];
temp_inputY *= inputYShape.data[i];
//--------------------------
if (conditionShape.data[i] == 1) {
conditionIdx += 0;
@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
v = v / outputShape.data[i];
}
output[outputIdx] =
condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
}
}
namespace infini {
void where_kernel(const float *inputx, const float *inputy,
const float *condition, float *output, int nDims,
infini::SmallArray inputxShape,
infini::SmallArray inputyShape,
infini::SmallArray conditionShape,
infini::SmallArray outputShape) {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape) {
int outputsize = 1;
for (int i = 0; i < nDims; i++) {
@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy,
}
int blocksize = 32 * 16;
int gridsize = (outputsize + blocksize - 1) / blocksize;
_where_kernel<<<gridsize, blocksize>>>(
inputx, inputy, condition, output, nDims, outputsize, inputxShape,
inputyShape, conditionShape, outputShape);
_whereKernel<<<gridsize, blocksize>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape);
}
} // namespace infini

View File

@ -10,11 +10,12 @@ namespace infini {
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
const Shape &inputyshape, const vector<float> &inputydata,
const Shape &conditionshape, const vector<int> &conditiondata,
const Shape &conditionshape,
const vector<uint8_t> &conditiondata,
const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
test_where(
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
test_where(Shape{2, 1, 1, 3}, // inputx
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
vector<int>{0, 1, 1, 0, 0, 0},
vector<uint8_t>{0, 1, 1, 0, 0, 0},
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
test_where(
Shape{
3,
},
vector<float>{0, 1, 2}, // inputX
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
test_where(
Shape{
3,
},
vector<float>{0, 1, 2}, // inputX
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
Shape{2, 1, 3, 1},
vector<uint8_t>{false, true, true, false, false, false}, // condition
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
} // python output