forked from jiuyuan/InfiniTensor
"modified where" (#131)
* "modified where" * "adapt int or bool condition datatype" * "add broadcast_shape.h,error" * add broadcast.h * "modified broadcast_shape.h and where.cc,.cu"
This commit is contained in:
parent
f60767a770
commit
dda668fd16
|
@ -3,7 +3,7 @@
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
#include "utils/small_array.h"
|
#include "utils/small_array.h"
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||||
SmallArray inputShape, SmallArray outputShape);
|
SmallArray inputShape, SmallArray outputShape);
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -3,11 +3,9 @@
|
||||||
#include "utils/small_array.h"
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void where_kernel(const float *inputx, const float *inputy,
|
void whereKernel(const float *inputX, const float *inputY,
|
||||||
const float *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
infini::SmallArray inputxShape,
|
SmallArray inputXShape, SmallArray inputYShape,
|
||||||
infini::SmallArray inputyShape,
|
SmallArray conditionShape, SmallArray outputShape);
|
||||||
infini::SmallArray conditionShape,
|
|
||||||
infini::SmallArray outputShape);
|
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
|
||||||
|
int nDims, int size) {
|
||||||
|
for (int i = nDims - 1; i >= 0; --i) {
|
||||||
|
modifyShape.data[i] = 1;
|
||||||
|
}
|
||||||
|
for (int i = size - 1; i >= 0; --i) {
|
||||||
|
modifyShape.data[i + nDims - size] = originShape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig {
|
||||||
inputShape.data[i] = in_Shape[i];
|
inputShape.data[i] = in_Shape[i];
|
||||||
outputsize *= out_Shape[i];
|
outputsize *= out_Shape[i];
|
||||||
}
|
}
|
||||||
expand_kernel((float *)inputData, (float *)outputData, nDims,
|
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
|
||||||
outputsize, inputShape, outputShape);
|
inputShape, outputShape);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
|
||||||
constexpr int thread_work_size() { return 4; }
|
constexpr int thread_work_size() { return 4; }
|
||||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||||
|
|
||||||
__global__ void _expand_kernel(float *input, float *output, int nDims,
|
__global__ void _expandKernel(float *input, float *output, int nDims,
|
||||||
int outputsize, infini::SmallArray inputShape,
|
int outputsize, infini::SmallArray inputShape,
|
||||||
infini::SmallArray outputShape) {
|
infini::SmallArray outputShape) {
|
||||||
|
|
||||||
|
@ -38,11 +38,11 @@ __global__ void _expand_kernel(float *input, float *output, int nDims,
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||||
SmallArray inputShape, SmallArray outputShape) {
|
SmallArray inputShape, SmallArray outputShape) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
|
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
|
||||||
_expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
||||||
inputShape, outputShape);
|
inputShape, outputShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
#include "cuda/cuda_where.h"
|
#include "cuda/cuda_where.h"
|
||||||
|
#include "utils/broadcast_shape.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig {
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
auto op = as<WhereObj>(_op);
|
auto op = as<WhereObj>(_op);
|
||||||
|
|
||||||
void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
|
void *const inputXData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
|
void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
|
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
const auto &inputx_Shape = op->getInputs(0)->getDims();
|
const auto &opInputXShape = op->getInputs(0)->getDims();
|
||||||
const auto &inputy_Shape = op->getInputs(1)->getDims();
|
const auto &opInputYShape = op->getInputs(1)->getDims();
|
||||||
const auto &condition_Shape = op->getInputs(2)->getDims();
|
const auto &opConditionShape = op->getInputs(2)->getDims();
|
||||||
const auto &output_Shape = op->getOutput()->getDims();
|
const auto &opOutputShape = op->getOutput()->getDims();
|
||||||
|
|
||||||
int nDims = op->getInputs(0)->getDims().size();
|
const int xSize = op->getInputs(0)->getRank();
|
||||||
|
const int ySize = op->getInputs(1)->getRank();
|
||||||
|
const int cSize = op->getInputs(2)->getRank();
|
||||||
|
int nDims = op->getOutput()->getDims().size();
|
||||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||||
|
|
||||||
SmallArray inputxShape, inputyShape, conditionShape, outputShape;
|
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = nDims - 1; i >= 0; --i) {
|
||||||
inputxShape.data[i] = inputx_Shape[i];
|
outputShape.data[i] = opOutputShape[i];
|
||||||
inputyShape.data[i] = inputy_Shape[i];
|
|
||||||
conditionShape.data[i] = condition_Shape[i];
|
|
||||||
outputShape.data[i] = output_Shape[i];
|
|
||||||
}
|
}
|
||||||
where_kernel((float *)inputxData, (float *)inputyData,
|
|
||||||
(float *)conditionData, (float *)outputData, nDims,
|
broadcastShape(opInputXShape, inputXShape, nDims, xSize);
|
||||||
inputxShape, inputyShape, conditionShape, outputShape);
|
broadcastShape(opInputYShape, inputYShape, nDims, ySize);
|
||||||
|
broadcastShape(opConditionShape, conditionShape, nDims, cSize);
|
||||||
|
|
||||||
|
whereKernel((float *)inputXData, (float *)inputYData,
|
||||||
|
(uint8_t *)conditionData, (float *)outputData, nDims,
|
||||||
|
inputXShape, inputYShape, conditionShape, outputShape);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,20 @@
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
#include "utils/small_array.h"
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
__global__ void _where_kernel(const float *inputx, const float *inputy,
|
__global__ void _whereKernel(const float *inputX, const float *inputY,
|
||||||
const float *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
int outputsize, infini::SmallArray inputxShape,
|
int outputsize, infini::SmallArray inputXShape,
|
||||||
infini::SmallArray inputyShape,
|
infini::SmallArray inputYShape,
|
||||||
infini::SmallArray conditionShape,
|
infini::SmallArray conditionShape,
|
||||||
infini::SmallArray outputShape) {
|
infini::SmallArray outputShape) {
|
||||||
|
|
||||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (outputIdx < outputsize) {
|
if (outputIdx < outputsize) {
|
||||||
int inputxIdx = 0;
|
int inputXIdx = 0;
|
||||||
int temp_inputx = 1;
|
int temp_inputX = 1;
|
||||||
|
|
||||||
int inputyIdx = 0;
|
int inputYIdx = 0;
|
||||||
int temp_inputy = 1;
|
int temp_inputY = 1;
|
||||||
|
|
||||||
int conditionIdx = 0;
|
int conditionIdx = 0;
|
||||||
int temp_condition = 1;
|
int temp_condition = 1;
|
||||||
|
@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
|
||||||
} else {
|
} else {
|
||||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
tmp = v % outputShape.data[i]; // store s,k,j in order
|
||||||
}
|
}
|
||||||
if (inputxShape.data[i] == 1) {
|
if (inputXShape.data[i] == 1) {
|
||||||
inputxIdx += 0;
|
inputXIdx += 0;
|
||||||
} else {
|
} else {
|
||||||
inputxIdx +=
|
inputXIdx +=
|
||||||
tmp *
|
tmp *
|
||||||
temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s
|
temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||||
}
|
}
|
||||||
temp_inputx *= inputxShape.data[i];
|
temp_inputX *= inputXShape.data[i];
|
||||||
//----------------------------
|
//----------------------------
|
||||||
if (inputyShape.data[i] == 1) {
|
if (inputYShape.data[i] == 1) {
|
||||||
inputyIdx += 0;
|
inputYIdx += 0;
|
||||||
} else {
|
} else {
|
||||||
inputyIdx +=
|
inputYIdx +=
|
||||||
tmp *
|
tmp *
|
||||||
temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s
|
temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||||
}
|
}
|
||||||
temp_inputy *= inputyShape.data[i];
|
temp_inputY *= inputYShape.data[i];
|
||||||
//--------------------------
|
//--------------------------
|
||||||
if (conditionShape.data[i] == 1) {
|
if (conditionShape.data[i] == 1) {
|
||||||
conditionIdx += 0;
|
conditionIdx += 0;
|
||||||
|
@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
|
||||||
v = v / outputShape.data[i];
|
v = v / outputShape.data[i];
|
||||||
}
|
}
|
||||||
output[outputIdx] =
|
output[outputIdx] =
|
||||||
condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
|
condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void where_kernel(const float *inputx, const float *inputy,
|
void whereKernel(const float *inputX, const float *inputY,
|
||||||
const float *condition, float *output, int nDims,
|
const uint8_t *condition, float *output, int nDims,
|
||||||
infini::SmallArray inputxShape,
|
SmallArray inputXShape, SmallArray inputYShape,
|
||||||
infini::SmallArray inputyShape,
|
SmallArray conditionShape, SmallArray outputShape) {
|
||||||
infini::SmallArray conditionShape,
|
|
||||||
infini::SmallArray outputShape) {
|
|
||||||
int outputsize = 1;
|
int outputsize = 1;
|
||||||
|
|
||||||
for (int i = 0; i < nDims; i++) {
|
for (int i = 0; i < nDims; i++) {
|
||||||
|
@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy,
|
||||||
}
|
}
|
||||||
int blocksize = 32 * 16;
|
int blocksize = 32 * 16;
|
||||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||||
_where_kernel<<<gridsize, blocksize>>>(
|
_whereKernel<<<gridsize, blocksize>>>(
|
||||||
inputx, inputy, condition, output, nDims, outputsize, inputxShape,
|
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||||
inputyShape, conditionShape, outputShape);
|
inputYShape, conditionShape, outputShape);
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -10,11 +10,12 @@ namespace infini {
|
||||||
|
|
||||||
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
||||||
const Shape &inputyshape, const vector<float> &inputydata,
|
const Shape &inputyshape, const vector<float> &inputydata,
|
||||||
const Shape &conditionshape, const vector<int> &conditiondata,
|
const Shape &conditionshape,
|
||||||
|
const vector<uint8_t> &conditiondata,
|
||||||
const vector<float> &ExpectData) {
|
const vector<float> &ExpectData) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
|
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
|
||||||
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
||||||
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
||||||
|
|
||||||
|
@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
|
||||||
test_where(
|
test_where(
|
||||||
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||||
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
||||||
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
||||||
|
|
||||||
test_where(Shape{2, 1, 1, 3}, // inputx
|
test_where(Shape{2, 1, 1, 3}, // inputx
|
||||||
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
||||||
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
||||||
vector<int>{0, 1, 1, 0, 0, 0},
|
vector<uint8_t>{0, 1, 1, 0, 0, 0},
|
||||||
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
||||||
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
||||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
||||||
|
test_where(
|
||||||
|
Shape{
|
||||||
|
3,
|
||||||
|
},
|
||||||
|
vector<float>{0, 1, 2}, // inputX
|
||||||
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
||||||
|
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
|
||||||
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
||||||
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
||||||
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
||||||
|
test_where(
|
||||||
|
Shape{
|
||||||
|
3,
|
||||||
|
},
|
||||||
|
vector<float>{0, 1, 2}, // inputX
|
||||||
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
||||||
|
Shape{2, 1, 3, 1},
|
||||||
|
vector<uint8_t>{false, true, true, false, false, false}, // condition
|
||||||
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
||||||
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
||||||
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
||||||
|
|
||||||
} // python output
|
} // python output
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue