forked from jiuyuan/InfiniTensor
"modified where" (#131)
* "modified where" * "adapt int or bool condition datatype" * "add broadcast_shape.h,error" * add broadcast.h * "modified broadcast_shape.h and where.cc,.cu"
This commit is contained in:
parent
f60767a770
commit
dda668fd16
|
@ -3,7 +3,7 @@
|
|||
#include "operators/unary.h"
|
||||
#include "utils/small_array.h"
|
||||
namespace infini {
|
||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
||||
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,11 +3,9 @@
|
|||
#include "utils/small_array.h"
|
||||
|
||||
namespace infini {
|
||||
void where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape);
|
||||
void whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
|
||||
int nDims, int size) {
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
modifyShape.data[i] = 1;
|
||||
}
|
||||
for (int i = size - 1; i >= 0; --i) {
|
||||
modifyShape.data[i + nDims - size] = originShape[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig {
|
|||
inputShape.data[i] = in_Shape[i];
|
||||
outputsize *= out_Shape[i];
|
||||
}
|
||||
expand_kernel((float *)inputData, (float *)outputData, nDims,
|
||||
outputsize, inputShape, outputShape);
|
||||
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
|
|||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _expand_kernel(float *input, float *output, int nDims,
|
||||
__global__ void _expandKernel(float *input, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputShape,
|
||||
infini::SmallArray outputShape) {
|
||||
|
||||
|
@ -38,11 +38,11 @@ __global__ void _expand_kernel(float *input, float *output, int nDims,
|
|||
}
|
||||
|
||||
namespace infini {
|
||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
||||
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape) {
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
|
||||
_expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
||||
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_where.h"
|
||||
#include "utils/broadcast_shape.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig {
|
|||
const RuntimeObj *_context) const override {
|
||||
auto op = as<WhereObj>(_op);
|
||||
|
||||
void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const inputXData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
const auto &inputx_Shape = op->getInputs(0)->getDims();
|
||||
const auto &inputy_Shape = op->getInputs(1)->getDims();
|
||||
const auto &condition_Shape = op->getInputs(2)->getDims();
|
||||
const auto &output_Shape = op->getOutput()->getDims();
|
||||
const auto &opInputXShape = op->getInputs(0)->getDims();
|
||||
const auto &opInputYShape = op->getInputs(1)->getDims();
|
||||
const auto &opConditionShape = op->getInputs(2)->getDims();
|
||||
const auto &opOutputShape = op->getOutput()->getDims();
|
||||
|
||||
int nDims = op->getInputs(0)->getDims().size();
|
||||
const int xSize = op->getInputs(0)->getRank();
|
||||
const int ySize = op->getInputs(1)->getRank();
|
||||
const int cSize = op->getInputs(2)->getRank();
|
||||
int nDims = op->getOutput()->getDims().size();
|
||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||
|
||||
SmallArray inputxShape, inputyShape, conditionShape, outputShape;
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
inputxShape.data[i] = inputx_Shape[i];
|
||||
inputyShape.data[i] = inputy_Shape[i];
|
||||
conditionShape.data[i] = condition_Shape[i];
|
||||
outputShape.data[i] = output_Shape[i];
|
||||
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
outputShape.data[i] = opOutputShape[i];
|
||||
}
|
||||
where_kernel((float *)inputxData, (float *)inputyData,
|
||||
(float *)conditionData, (float *)outputData, nDims,
|
||||
inputxShape, inputyShape, conditionShape, outputShape);
|
||||
|
||||
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
||||
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
||||
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
||||
|
||||
whereKernel((float *)inputXData, (float *)inputYData,
|
||||
(uint8_t *)conditionData, (float *)outputData, nDims,
|
||||
inputXShape, inputYShape, conditionShape, outputShape);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
__global__ void _where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
__global__ void _whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputXShape,
|
||||
infini::SmallArray inputYShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape) {
|
||||
|
||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (outputIdx < outputsize) {
|
||||
int inputxIdx = 0;
|
||||
int temp_inputx = 1;
|
||||
int inputXIdx = 0;
|
||||
int temp_inputX = 1;
|
||||
|
||||
int inputyIdx = 0;
|
||||
int temp_inputy = 1;
|
||||
int inputYIdx = 0;
|
||||
int temp_inputY = 1;
|
||||
|
||||
int conditionIdx = 0;
|
||||
int temp_condition = 1;
|
||||
|
@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
|
|||
} else {
|
||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
||||
}
|
||||
if (inputxShape.data[i] == 1) {
|
||||
inputxIdx += 0;
|
||||
if (inputXShape.data[i] == 1) {
|
||||
inputXIdx += 0;
|
||||
} else {
|
||||
inputxIdx +=
|
||||
inputXIdx +=
|
||||
tmp *
|
||||
temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputx *= inputxShape.data[i];
|
||||
temp_inputX *= inputXShape.data[i];
|
||||
//----------------------------
|
||||
if (inputyShape.data[i] == 1) {
|
||||
inputyIdx += 0;
|
||||
if (inputYShape.data[i] == 1) {
|
||||
inputYIdx += 0;
|
||||
} else {
|
||||
inputyIdx +=
|
||||
inputYIdx +=
|
||||
tmp *
|
||||
temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputy *= inputyShape.data[i];
|
||||
temp_inputY *= inputYShape.data[i];
|
||||
//--------------------------
|
||||
if (conditionShape.data[i] == 1) {
|
||||
conditionIdx += 0;
|
||||
|
@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
|
|||
v = v / outputShape.data[i];
|
||||
}
|
||||
output[outputIdx] =
|
||||
condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
|
||||
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape) {
|
||||
void whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape) {
|
||||
int outputsize = 1;
|
||||
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
|
@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy,
|
|||
}
|
||||
int blocksize = 32 * 16;
|
||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||
_where_kernel<<<gridsize, blocksize>>>(
|
||||
inputx, inputy, condition, output, nDims, outputsize, inputxShape,
|
||||
inputyShape, conditionShape, outputShape);
|
||||
_whereKernel<<<gridsize, blocksize>>>(
|
||||
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||
inputYShape, conditionShape, outputShape);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -10,11 +10,12 @@ namespace infini {
|
|||
|
||||
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
||||
const Shape &inputyshape, const vector<float> &inputydata,
|
||||
const Shape &conditionshape, const vector<int> &conditiondata,
|
||||
const Shape &conditionshape,
|
||||
const vector<uint8_t> &conditiondata,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
|
||||
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
|
||||
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
||||
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
||||
|
||||
|
@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
|
|||
test_where(
|
||||
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
||||
Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
||||
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
||||
|
||||
test_where(Shape{2, 1, 1, 3}, // inputx
|
||||
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
||||
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
||||
vector<int>{0, 1, 1, 0, 0, 0},
|
||||
vector<uint8_t>{0, 1, 1, 0, 0, 0},
|
||||
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
||||
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
||||
test_where(
|
||||
Shape{
|
||||
3,
|
||||
},
|
||||
vector<float>{0, 1, 2}, // inputX
|
||||
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
||||
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
|
||||
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
||||
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
||||
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
||||
test_where(
|
||||
Shape{
|
||||
3,
|
||||
},
|
||||
vector<float>{0, 1, 2}, // inputX
|
||||
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
||||
Shape{2, 1, 3, 1},
|
||||
vector<uint8_t>{false, true, true, false, false, false}, // condition
|
||||
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
||||
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
||||
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
||||
|
||||
} // python output
|
||||
|
||||
|
|
Loading…
Reference in New Issue