forked from jiuyuan/InfiniTensor
Compare commits: master...cuda-trans (5 commits)

Author     | SHA1       | Date
-----------|------------|-----
xgqdut2016 | 7146294baa |
xgqdut2016 | 73e3f1fc6f |
xgqdut2016 | 86133c8d0a |
xgqdut2016 | 2761d46737 |
xgqdut2016 | aa1c3222ed |
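Taken together, these commits give the expand, element-wise (div/add/pow/less), where, and transpose CUDA kernels a specialized path for tensors of rank 4 or less: instead of driving a rank-generic loop through SmallArray shape descriptors, the operators left-pad every shape to exactly four extents and pass them as scalar kernel arguments. A minimal sketch of the recurring dispatch idiom, with illustrative names (the real call sites appear in the diffs below):

#include <algorithm>
#include <vector>

// Hypothetical wrapper showing the rank test + pad-to-4D idiom these commits
// introduce; dispatchExpand and the comments are illustrative only.
void dispatchExpand(const std::vector<int> &a_dim,
                    const std::vector<int> &b_dim) {
    if (a_dim.size() > 4 || b_dim.size() > 4) {
        // slow path: rank-generic kernel driven by SmallArray descriptors
    } else {
        // fast path: left-pad both shapes to four extents with 1s, so one
        // kernel signature (a0..a3, b0..b3) covers every rank <= 4
        int a[4] = {1, 1, 1, 1}, b[4] = {1, 1, 1, 1};
        std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
        std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
        // expandKernel(dType, in, out, a[0], a[1], a[2], a[3],
        //              b[0], b[1], b[2], b[3]);
    }
}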
(The capture dropped the per-file headers, so each diff below is labeled by its content.)

Expand kernel declarations:

@@ -3,10 +3,11 @@
 #include "operators/unary.h"
 #include "utils/small_array.h"
 namespace infini {
+void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
+                  int a3, int b0, int b1, int b2, int b3);
 void expandKernel(int dType, void *input, void *output, int nDims,
                   int outputsize, SmallArray inputShape,
                   SmallArray outputShape);

 void expandRowKernel(int dType, void *input, void *output, int n_rows,
                      int row_len);
 }; // namespace infini
Where kernel declarations:

@@ -1,16 +1,14 @@
 #pragma once
 #include "operators/unary.h"
 #include "utils/small_array.h"

 namespace infini {

-void whereKernel(const float *inputX, const float *inputY,
-                 const uint8_t *condition, float *output, int nDims,
-                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape, int xSize,
-                 int ySize, int cSize);
-void whereKernel(const half *inputX, const half *inputY,
-                 const uint8_t *condition, half *output, int nDims,
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int a0, int a1, int a2,
+                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
+                 int c3, int d0, int d1, int d2, int d3);
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int nDims,
                  int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                  SmallArray conditionShape, SmallArray outputShape, int xSize,
                  int ySize, int cSize);
Element-wise kernels (div / add / pow / less):

@@ -5,34 +5,42 @@
 constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+const int repeat = 1;
 template <class T>
 __global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                               a2_index * a3 + a3_index] /
-                      ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                               b2_index * b3 + b3_index];
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = ((T *)x)[xIdx] / ((T *)y)[yIdx];
     }
 }
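The rewritten _div_kernel takes the broadcast decision out of the hot loop: an operand whose total size equals n is indexed directly by i, a size-1 operand stays at index 0, and the coordinate decomposition runs only when genuine broadcasting occurs. The same rewrite is applied to add, pow, and less below. For reference, a host-side copy of the index arithmetic (a sketch for checking the kernel, not code from the diff):

#include <cassert>

// Flat output index i over shape (c0, c1, c2, c3), folded back modulo a
// broadcast operand shape (a0, a1, a2, a3) - mirrors the xIdx math above.
int broadcastIndex(int i, int c1, int c2, int c3,
                   int a0, int a1, int a2, int a3) {
    int stride1 = c2 * c3;
    int stride0 = c1 * stride1;
    int c0_index = i / stride0;
    int c1_index = (i % stride0) / stride1;
    int c2_index = (i % stride1) / c3;
    int c3_index = i % c3;
    return (c0_index % a0) * a1 * a2 * a3 + (c1_index % a1) * a2 * a3 +
           (c2_index % a2) * a3 + c3_index % a3;
}

int main() {
    // output shape {1,1,2,3}, x broadcast from shape {1,1,1,3}:
    // element i = 4 is row 1, column 1 -> x index 1
    assert(broadcastIndex(4, 1, 2, 3, 1, 1, 1, 3) == 1);
}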
@@ -41,28 +49,36 @@ __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                               a2_index * a3 + a3_index] +
-                      ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                               b2_index * b3 + b3_index];
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = ((T *)x)[xIdx] + ((T *)y)[yIdx];
     }
 }
@@ -71,29 +87,36 @@ __global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((T *)z)[i] =
-            pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                         a2_index * a3 + a3_index],
-                ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                         b2_index * b3 + b3_index]);
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((T *)z)[i] = pow(((T *)x)[xIdx], ((T *)y)[yIdx]);
     }
 }
@@ -102,31 +125,36 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
                             int a3, int b0, int b1, int b2, int b3, int c0,
                             int c1, int c2, int c3) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
-    int stride = blockDim.x * gridDim.x;
-    int n = c0 * c1 * c2 * c3;
-
-    for (int i = index; i < n; i += stride) {
-        int c0_index = i / (c1 * c2 * c3);
-        int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
-        int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
-        int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
-
-        int a0_index = c0_index % a0;
-        int a1_index = c1_index % a1;
-        int a2_index = c2_index % a2;
-        int a3_index = c3_index % a3;
-
-        int b0_index = c0_index % b0;
-        int b1_index = c1_index % b1;
-        int b2_index = c2_index % b2;
-        int b3_index = c3_index % b3;
-        ((bool *)z)[i] =
-            ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
-                     a2_index * a3 + a3_index] <
-            ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
-                     b2_index * b3 + b3_index]
-                ? true
-                : false;
+    int stride1 = c2 * c3;
+    int stride0 = c1 * stride1;
+    int n = c0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int yIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        if (aIdx || bIdx) {
+            int c0_index = i / stride0;
+            int c1_index = (i % stride0) / stride1;
+            int c2_index = (i % stride1) / c3;
+            int c3_index = i % c3;
+            if (aIdx) {
+                xIdx = (c0_index % a0) * a1 * a2 * a3 +
+                       (c1_index % a1) * a2 * a3 + (c2_index % a2) * a3 +
+                       c3_index % a3;
+            }
+            if (bIdx) {
+                yIdx = (c0_index % b0) * b1 * b2 * b3 +
+                       (c1_index % b1) * b2 * b3 + (c2_index % b2) * b3 +
+                       c3_index % b3;
+            }
+        }
+        ((bool *)z)[i] = ((T *)x)[xIdx] < ((T *)y)[yIdx] ? true : false;
     }
 }
@@ -176,7 +204,6 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
     default:                                                                   \
         IT_TODO_HALT();                                                        \
     }
-
 template <class T>
 __global__ void _div_const_kernel(void const *__restrict__ x,
                                   void const *__restrict__ y,
@@ -269,7 +296,8 @@ void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,

     int blocksize = block_work_size();
     int num = c0 * c1 * c2 * c3;
-    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    int gridsize =
+        (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
     SWITCH_DTYPE(div, dType)
 }
 void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -278,7 +306,8 @@ void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,

     int blocksize = block_work_size();
     int num = c0 * c1 * c2 * c3;
-    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    int gridsize =
+        (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
     SWITCH_DTYPE(add, dType)
 }
 void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
@@ -286,7 +315,8 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
                 int c3) {
     int blocksize = block_work_size();
     int num = c0 * c1 * c2 * c3;
-    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    int gridsize =
+        (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
     if (dType == 1) {
         _pow_kernel<float>
             <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
@@ -324,7 +354,8 @@ void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
                 int c3) {
     int blocksize = block_work_size();
     int num = c0 * c1 * c2 * c3;
-    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    int gridsize =
+        (num + repeat * block_work_size() - 1) / (repeat * block_work_size());
     SWITCH_DTYPE(less, dType)
 }
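All four launchers now size the grid for `repeat` elements per thread. With the defaults above — num_threads() == 128, thread_work_size() == 4, hence block_work_size() == 512 — and repeat == 1, this stays an ordinary ceiling division. A standalone restatement:

constexpr int ceilDiv(int num, int chunk) { return (num + chunk - 1) / chunk; }
// 1000 output elements, chunk = repeat * block_work_size() = 1 * 512:
static_assert(ceilDiv(1000, 1 * 512) == 2, "two blocks cover 1000 elements");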
Expand operator (kernel registration):

@@ -12,22 +12,33 @@ class ExpandCuda : public CudaKernelWithoutConfig {

         void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
-        const auto &in_Shape = op->getInputs(0)->getDims(); // input shape
-        const auto &out_Shape = op->getShape();             // output shape
+        auto a_dim = op->getInputs(0)->getDims();
+        auto b_dim = op->getOutput()->getDims(); // output shape

-        SmallArray inputShape, outputShape;
-        int nDims = op->getInputs(0)->getDims().size();
-
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int outputsize = 1; // the length of the output vector after flatten
-        for (int i = 0; i < nDims; ++i) {
-            outputShape.data[i] = out_Shape[i];
-            inputShape.data[i] = in_Shape[i];
-            outputsize *= out_Shape[i];
-        }
         const int dType = op->getDType().getIndex();
-        expandKernel(dType, inputData, outputData, nDims, outputsize,
-                     inputShape, outputShape);
+        if (a_dim.size() > 4 || b_dim.size() > 4) {
+            SmallArray inputShape, outputShape;
+            int nDims = op->getInputs(0)->getDims().size();
+
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int outputsize = 1; // the length of the output vector after flatten
+            for (int i = 0; i < nDims; ++i) {
+                outputShape.data[i] = b_dim[i];
+                inputShape.data[i] = a_dim[i];
+                outputsize *= b_dim[i];
+            }
+            const int dType = op->getDType().getIndex();
+            expandKernel(dType, inputData, outputData, nDims, outputsize,
+                         inputShape, outputShape);
+
+        } else {
+            int a[4] = {1, 1, 1, 1};
+            int b[4] = {1, 1, 1, 1};
+            std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
+            std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
+            expandKernel(dType, inputData, outputData, a[0], a[1], a[2], a[3],
+                         b[0], b[1], b[2], b[3]);
+        }
     }
 };
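The else branch relies on left-padding both shapes to exactly four extents so that one (a0..a3, b0..b3) signature serves every rank up to 4. A standalone check of that padding, using assumed example shapes:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
    std::vector<int> a_dim{3, 1};    // input, rank 2
    std::vector<int> b_dim{2, 3, 4}; // output, rank 3
    int a[4] = {1, 1, 1, 1}, b[4] = {1, 1, 1, 1};
    std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
    std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
    assert(a[0] == 1 && a[1] == 1 && a[2] == 3 && a[3] == 1); // {1,1,3,1}
    assert(b[0] == 1 && b[1] == 2 && b[2] == 3 && b[3] == 4); // {1,2,3,4}
    return 0;
}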
Expand kernels:

@@ -6,7 +6,31 @@
 constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+const int repeat = 1;
+template <class T>
+__global__ void _expandKernel(void *input, void *output, int a0, int a1, int a2,
+                              int a3, int b0, int b1, int b2, int b3) {
+
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+
+    int stride1 = b2 * b3;
+    int stride0 = b1 * stride1;
+    int n = b0 * stride0;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int xIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        if (aIdx) {
+            int b0_index = i / stride0;
+            int b1_index = (i % stride0) / stride1;
+            int b2_index = (i % stride1) / b3;
+            int b3_index = i % b3;
+            xIdx = (b0_index % a0) * a1 * a2 * a3 + (b1_index % a1) * a2 * a3 +
+                   (b2_index % a2) * a3 + b3_index % a3;
+        }
+        ((T *)output)[i] = ((T *)input)[xIdx];
+    }
+}
 template <class T>
 __global__ void _expandKernel(void *input, void *output, int nDims,
                               int outputsize, infini::SmallArray inputShape,
@@ -38,7 +62,6 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
         ((T *)output)[outputIdx] = ((T *)input)[inputIdx];
     }
 }
-
 template <class T>
 static __global__ void _expandRowKernel(void *__restrict__ dst,
                                         void const *__restrict__ src) {
@@ -50,9 +73,9 @@ static __global__ void _expandRowKernel(void *__restrict__ dst,
 namespace infini {

 #define CASE(T)                                                                \
-    _expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize,                        \
-                                   0, CUDAStream::getCurrentStream()>>>(       \
-        input, output, nDims, outputsize, inputShape, outputShape);
+    _expandKernel<DT_CUDA<T>::t>                                               \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
+            input, output, a0, a1, a2, a3, b0, b1, b2, b3);

 #define SWITCH_DTYPE(DTYPE)                                                    \
     switch (DTYPE) {                                                           \
@@ -96,14 +119,56 @@ namespace infini {
         IT_TODO_HALT();                                                        \
     }

+void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
+                  int a3, int b0, int b1, int b2, int b3) {
+    int blocksize = block_work_size();
+    int outputsize = b0 * b1 * b2 * b3;
+    int gridsize = (outputsize + repeat * block_work_size() - 1) /
+                   (repeat * block_work_size());
+    SWITCH_DTYPE(dType)
+}
+#define CASECurrency(T)                                                        \
+    _expandKernel<DT_CUDA<T>::t>                                               \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
+            input, output, nDims, outputsize, inputShape, outputShape);
+
+#define SWITCHCurrency_DTYPE(DTYPE)                                            \
+    switch (DTYPE) {                                                           \
+    case 1:                                                                    \
+        CASECurrency(1) break;                                                 \
+    case 2:                                                                    \
+        CASECurrency(2) break;                                                 \
+    case 3:                                                                    \
+        CASECurrency(3) break;                                                 \
+    case 4:                                                                    \
+        CASECurrency(4) break;                                                 \
+    case 5:                                                                    \
+        CASECurrency(5) break;                                                 \
+    case 6:                                                                    \
+        CASECurrency(6) break;                                                 \
+    case 7:                                                                    \
+        CASECurrency(7) break;                                                 \
+    case 10:                                                                   \
+        CASECurrency(10) break;                                                \
+    case 11:                                                                   \
+        CASECurrency(11) break;                                                \
+    case 12:                                                                   \
+        CASECurrency(12) break;                                                \
+    case 13:                                                                   \
+        CASECurrency(13) break;                                                \
+    case 16:                                                                   \
+        CASECurrency(16) break;                                                \
+    default:                                                                   \
+        IT_TODO_HALT();                                                        \
+    }
+
 void expandKernel(int dType, void *input, void *output, int nDims,
                   int outputsize, SmallArray inputShape,
                   SmallArray outputShape) {
     int blocksize = block_work_size();
     int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
-    SWITCH_DTYPE(dType)
+    SWITCHCurrency_DTYPE(dType)
 }

 #define CASE_ROW(T)                                                            \
     _expandRowKernel<float>                                                    \
         <<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
@@ -150,7 +215,8 @@ void expandKernel(int dType, void *input, void *output, int nDims,
         IT_TODO_HALT();                                                        \
     }

-// Optimization for expanding a row vector. The row length must be a multiple of 32
+// Optimization for expanding a row vector. The row length must be a multiple of
+// 32
 void expandRowKernel(int dType, void *input, void *output, int n_rows,
                      int row_len) {
     // Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
@@ -160,7 +226,8 @@ void expandRowKernel(int dType, void *input, void *output, int n_rows,
     // block: b x 32
     auto c = row_len / 32, b = c;
     if (b > 32) {
-        for (b = 32; c % b != 0; --b);
+        for (b = 32; c % b != 0; --b)
+            ;
     }
     auto a = c / b;
     dim3 grid(a, n_rows), block(32, b);
|
||||||
beta_naive = 1.f;
|
beta_naive = 1.f;
|
||||||
auto inC = op->getInputs(2);
|
auto inC = op->getInputs(2);
|
||||||
auto out = op->getOutput();
|
auto out = op->getOutput();
|
||||||
SmallArray inputShape, outputShape;
|
|
||||||
int nDims = out->getRank();
|
|
||||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
|
||||||
// FIXME(constroy): use size_t for outputsize.
|
|
||||||
int outputsize = 1; // the length of the output vector after flatten
|
|
||||||
int offset = nDims - inC->getRank();
|
|
||||||
for (int i = 0; i < offset; ++i)
|
|
||||||
inputShape.data[i] = 1;
|
|
||||||
for (int i = 0; i < nDims; ++i) {
|
|
||||||
outputShape.data[i] = out->getDims()[i];
|
|
||||||
outputsize *= outputShape.data[i];
|
|
||||||
if (i >= offset)
|
|
||||||
inputShape.data[i] = inC->getDims()[i - offset];
|
|
||||||
}
|
|
||||||
const int dType = dataType.getIndex();
|
const int dType = dataType.getIndex();
|
||||||
|
|
||||||
// Bias in linear layer is row vector of (1,n), n is the number of
|
// Bias in linear layer is row vector of (1,n), n is the number of
|
||||||
|
@@ -111,9 +98,40 @@ class matmulCublas : public Kernel {
                                  out->size() / inC->getDims()[0],
                                  inC->getDims()[0]);
             } else {
-                expandKernel(dType, inC->getRawDataPtr<void *>(),
-                             out->getRawDataPtr<void *>(), nDims, outputsize,
-                             inputShape, outputShape);
+                auto a_dim = out->getDims();
+                auto b_dim = inC->getDims(); // output shape
+
+                if (a_dim.size() > 4 || b_dim.size() > 4) {
+                    SmallArray inputShape, outputShape;
+                    int nDims = out->getRank();
+                    IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+                    // FIXME(constroy): use size_t for outputsize.
+                    int outputsize =
+                        1; // the length of the output vector after flatten
+                    int offset = nDims - inC->getRank();
+                    for (int i = 0; i < offset; ++i)
+                        inputShape.data[i] = 1;
+                    for (int i = 0; i < nDims; ++i) {
+                        outputShape.data[i] = out->getDims()[i];
+                        outputsize *= outputShape.data[i];
+                        if (i >= offset)
+                            inputShape.data[i] = inC->getDims()[i - offset];
+                    }
+                    expandKernel(dType, inC->getRawDataPtr<void *>(),
+                                 out->getRawDataPtr<void *>(), nDims,
+                                 outputsize, inputShape, outputShape);
+
+                } else {
+                    int a[4] = {1, 1, 1, 1};
+                    int b[4] = {1, 1, 1, 1};
+                    std::copy(a_dim.begin(), a_dim.end(),
+                              a + (4 - a_dim.size()));
+                    std::copy(b_dim.begin(), b_dim.end(),
+                              b + (4 - b_dim.size()));
+                    expandKernel(dType, inC->getRawDataPtr<void *>(),
+                                 out->getRawDataPtr<void *>(), a[0], a[1], a[2],
+                                 a[3], b[0], b[1], b[2], b[3]);
+                }
             }
         }
         // TODO:use compute type
Transpose operator (kernel registration):

@@ -16,31 +16,57 @@ class TransposeCuda : public CudaKernelWithoutConfig {
         void *const outputData = output->getRawDataPtr<void *>();
         const auto &inputShape = input->getDims();
         const auto &outputShape = output->getDims();
-        const auto &perm = op->getPermute();
+        const int dType = op->getDType().getIndex();
         int size = input->size();
         int nDims = input->getDims().size();
-        // Compute strides
-        SmallArray strides, buffer;
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int curStride = 1;
-        for (int i = nDims - 1; i >= 0; --i) {
-            buffer.data[i] = curStride;
-            curStride *= inputShape[i];
-        }
-        for (int i = 0; i < nDims; ++i) {
-            strides.data[i] = buffer.data[perm[i]];
-        }
-
-        SmallArray outputDims;
-        for (int i = 0; i < nDims; ++i) {
-            outputDims.data[i] = outputShape[i];
-        }
-        const int dType = op->getDType().getIndex();
-        transpose_kernel(dType, inputData, outputData, nDims, size, strides,
-                         outputDims);
+        //----------------
+        bool condition = true;
+        int gnum = 0;
+        for (int i = 0; i < nDims; i++) {
+            if (inputShape[i] > 1) {
+                while (gnum < nDims) {
+                    if (outputShape[gnum] > 1) {
+                        gnum += 1;
+                        break;
+                    } else {
+                        gnum += 1;
+                    }
+                }
+                if (inputShape[i] != outputShape[gnum - 1]) {
+                    condition = false;
+                    break;
+                }
+            }
+        }
+        //----------------
+        if (condition) {
+            cudaMemcpyAsync(outputData, inputData, op->getInputs(0)->getBytes(),
+                            cudaMemcpyDeviceToDevice,
+                            CUDAStream::getCurrentStream());

+        } else {
+            const auto &perm = op->getPermute();
+
+            // Compute strides
+            SmallArray strides, buffer;
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int curStride = 1;
+            for (int i = nDims - 1; i >= 0; --i) {
+                buffer.data[i] = curStride;
+                curStride *= inputShape[i];
+            }
+            for (int i = 0; i < nDims; ++i) {
+                strides.data[i] = buffer.data[perm[i]];
+            }
+
+            SmallArray outputDims;
+            for (int i = 0; i < nDims; ++i) {
+                outputDims.data[i] = outputShape[i];
+            }
+
+            transpose_kernel(dType, inputData, outputData, nDims, size,
+                             strides, outputDims);
+        }
     }
 };
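The new `condition` scan lets the transpose collapse to a single cudaMemcpyAsync when the permutation moves only extent-1 axes — that is, when input and output shapes contain the same sequence of extents greater than 1, so the linear memory layout is unchanged. A host-side sketch of that intent (simplified relative to the loop above):

#include <cassert>
#include <vector>

// True when input and output share the same sequence of >1 extents, so
// "transposing" is a straight copy.
bool isLayoutPreserving(const std::vector<int> &in,
                        const std::vector<int> &out) {
    size_t g = 0;
    for (size_t i = 0; i < in.size(); i++) {
        if (in[i] > 1) {
            while (g < out.size() && out[g] <= 1)
                g++;
            if (g == out.size() || in[i] != out[g])
                return false;
            g++;
        }
    }
    return true;
}

int main() {
    assert(isLayoutPreserving({2, 1, 3}, {2, 3, 1})); // only a 1-axis moved
    assert(!isLayoutPreserving({2, 3}, {3, 2}));      // a real transpose
}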
Transpose kernel launch macro:

@@ -24,8 +24,8 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
 }
 #define CASE(T)                                                                \
     _transpose_kernel<DT_CUDA<T>::t>                                           \
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>           \
-        (input, output, nDims, size, strides, outputShape);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
+            input, output, nDims, size, strides, outputShape);

 #define SWITCH_DTYPE(DTYPE)                                                    \
     switch (DTYPE) {                                                           \
Where operator (kernel registration):

@@ -1,8 +1,8 @@
 #include "operators/where.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
 #include "cuda/cuda_where.h"
-#include "utils/operator_utils.h"

 namespace infini {
@@ -15,39 +15,50 @@ class WhereCuda : public CudaKernelWithoutConfig {
         void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
-        const auto &opInputXShape = op->getInputs(0)->getDims();
-        const auto &opInputYShape = op->getInputs(1)->getDims();
-        const auto &opConditionShape = op->getInputs(2)->getDims();
-        const auto &opOutputShape = op->getOutput()->getDims();

-        const int xSize = op->getInputs(0)->getRank();
-        const int ySize = op->getInputs(1)->getRank();
-        const int cSize = op->getInputs(2)->getRank();
-        int nDims = op->getOutput()->getDims().size();
-        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
-        int outputsize = 1;
-        SmallArray inputXShape, inputYShape, conditionShape, outputShape;
-        for (int i = nDims - 1; i >= 0; --i) {
-            outputShape.data[i] = opOutputShape[i];
-            outputsize *= outputShape.data[i];
-        }
-        broadcastShape(opInputXShape, inputXShape, nDims, xSize);
-        broadcastShape(opInputYShape, inputYShape, nDims, ySize);
-        broadcastShape(opConditionShape, conditionShape, nDims, cSize);
-
-        if (op->getDType() == DataType::Float32) {
-            whereKernel((float *)inputXData, (float *)inputYData,
-                        (uint8_t *)conditionData, (float *)outputData, nDims,
-                        outputsize, inputXShape, inputYShape, conditionShape,
-                        outputShape, xSize, ySize, cSize);
-        } else if (op->getDType() == DataType::Float16) {
-            whereKernel((half *)inputXData, (half *)inputYData,
-                        (uint8_t *)conditionData, (half *)outputData, nDims,
-                        outputsize, inputXShape, inputYShape, conditionShape,
-                        outputShape, xSize, ySize, cSize);
-        } else {
-            IT_ASSERT(false);
+        auto a_dim = op->getInputs(0)->getDims();
+        auto b_dim = op->getInputs(1)->getDims();
+        auto c_dim = op->getInputs(2)->getDims();
+        auto d_dim = op->getOutput()->getDims();
+        const int dTypeIndex = op->getDType().getIndex();
+        if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
+            d_dim.size() > 4) {
+            const int xSize = op->getInputs(0)->getRank();
+            const int ySize = op->getInputs(1)->getRank();
+            const int cSize = op->getInputs(2)->getRank();
+
+            int nDims = op->getOutput()->getDims().size();
+            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+            int outputsize = 1;
+            SmallArray inputXShape, inputYShape, conditionShape, outputShape;
+            for (int i = nDims - 1; i >= 0; --i) {
+                outputShape.data[i] = d_dim[i];
+                outputsize *= outputShape.data[i];
+            }
+            broadcastShape(a_dim, inputXShape, nDims, xSize);
+            broadcastShape(b_dim, inputYShape, nDims, ySize);
+            broadcastShape(c_dim, conditionShape, nDims, cSize);
+            whereKernel(dTypeIndex, inputXData, inputYData,
+                        (uint8_t *)conditionData, outputData, nDims, outputsize,
+                        inputXShape, inputYShape, conditionShape, outputShape,
+                        xSize, ySize, cSize);
+        }
+
+        else {
+            int a[4] = {1, 1, 1, 1};
+            int b[4] = {1, 1, 1, 1};
+            int c[4] = {1, 1, 1, 1};
+            int d[4] = {1, 1, 1, 1};
+
+            std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
+            std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
+            std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
+            std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
+
+            whereKernel(dTypeIndex, inputXData, inputYData,
+                        (uint8_t *)conditionData, outputData, a[0], a[1], a[2],
+                        a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3],
+                        d[0], d[1], d[2], d[3]);
         }
     }
 };
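The rank-above-4 path still depends on broadcastShape to right-align each operand against the output rank. A sketch of the behavior this code assumes of that helper (the helper itself is defined elsewhere in the repository):

// Left-pad a rank-`size` shape to `nDims` entries: missing leading dimensions
// broadcast, so they become extent 1.
void padShape(const int *shape, int size, int nDims, int *out) {
    for (int i = 0; i < nDims - size; ++i)
        out[i] = 1;
    for (int i = 0; i < size; ++i)
        out[nDims - size + i] = shape[i];
}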
Where kernels:

@@ -1,6 +1,109 @@
 #include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
 #include "utils/small_array.h"
+const int repeat = 1;
+
+template <typename T>
+__global__ void
+_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
+             int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3,
+             int c0, int c1, int c2, int c3, int d0, int d1, int d2, int d3) {
+
+    int stride1 = d2 * d3;
+    int stride0 = d1 * stride1;
+    int n = d0 * stride0;
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int end = (repeat * index + repeat < n ? repeat * index + repeat : n);
+    for (int i = repeat * index; i < end; i++) {
+        int inputXIdx = (a0 * a1 * a2 * a3 == n ? i : 0);
+        int inputYIdx = (b0 * b1 * b2 * b3 == n ? i : 0);
+        int conditionIdx = (c0 * c1 * c2 * c3 == n ? i : 0);
+
+        bool aIdx = (a0 * a1 * a2 * a3 < n && a0 * a1 * a2 * a3 > 1);
+        bool bIdx = (b0 * b1 * b2 * b3 < n && b0 * b1 * b2 * b3 > 1);
+        bool cIdx = (c0 * c1 * c2 * c3 < n && c0 * c1 * c2 * c3 > 1);
+        if (aIdx || bIdx || cIdx) {
+            int d0_index = i / stride0;
+            int d1_index = (i % stride0) / stride1;
+            int d2_index = (i % stride1) / d3;
+            int d3_index = i % d3;
+            if (aIdx) {
+                int a0_index = d0_index % a0;
+                int a1_index = d1_index % a1;
+                int a2_index = d2_index % a2;
+                int a3_index = d3_index % a3;
+                inputXIdx = a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
+                            a2_index * a3 + a3_index;
+            }
+            if (bIdx) {
+                int b0_index = d0_index % b0;
+                int b1_index = d1_index % b1;
+                int b2_index = d2_index % b2;
+                int b3_index = d3_index % b3;
+                inputYIdx = b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
+                            b2_index * b3 + b3_index;
+            }
+            if (cIdx) {
+                int c0_index = d0_index % c0;
+                int c1_index = d1_index % c1;
+                int c2_index = d2_index % c2;
+                int c3_index = d3_index % c3;
+                conditionIdx = c0_index * c1 * c2 * c3 + c1_index * c2 * c3 +
+                               c2_index * c3 + c3_index;
+            }
+        }
+
+        ((T *)output)[i] = condition[conditionIdx] ? ((T *)inputX)[inputXIdx]
+                                                   : ((T *)inputY)[inputYIdx];
+    }
+}
+#define CASE(T)                                                                \
+    _whereKernel<DT_CUDA<T>::t>                                                \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
+            inputX, inputY, condition, output, a0, a1, a2, a3, b0, b1, b2, b3, \
+            c0, c1, c2, c3, d0, d1, d2, d3);
+
+#define SWITCH_DTYPE(DTYPE)                                                    \
+    switch (DTYPE) {                                                           \
+    case 1:                                                                    \
+        CASE(1)                                                                \
+        break;                                                                 \
+    case 2:                                                                    \
+        CASE(2)                                                                \
+        break;                                                                 \
+    case 3:                                                                    \
+        CASE(3)                                                                \
+        break;                                                                 \
+    case 4:                                                                    \
+        CASE(4)                                                                \
+        break;                                                                 \
+    case 5:                                                                    \
+        CASE(5)                                                                \
+        break;                                                                 \
+    case 6:                                                                    \
+        CASE(6)                                                                \
+        break;                                                                 \
+    case 7:                                                                    \
+        CASE(7)                                                                \
+        break;                                                                 \
+    case 10:                                                                   \
+        CASE(10)                                                               \
+        break;                                                                 \
+    case 11:                                                                   \
+        CASE(11)                                                               \
+        break;                                                                 \
+    case 12:                                                                   \
+        CASE(12)                                                               \
+        break;                                                                 \
+    case 13:                                                                   \
+        CASE(13)                                                               \
+        break;                                                                 \
+    case 16:                                                                   \
+        CASE(16)                                                               \
+        break;                                                                 \
+    default:                                                                   \
+        IT_TODO_HALT();                                                        \
+    }
 __device__ int inferIndex(infini::SmallArray inputShape,
                           infini::SmallArray outputShape, int nDims, int size,
                           int outputIdx) {
@@ -19,11 +122,10 @@ __device__ int inferIndex(infini::SmallArray inputShape,
 }
 template <typename T>
 __global__ void
-_whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
-             T *output, int nDims, int outputsize,
-             infini::SmallArray inputXShape, infini::SmallArray inputYShape,
-             infini::SmallArray conditionShape, infini::SmallArray outputShape,
-             int xSize, int ySize, int cSize) {
-
+_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
+             int nDims, int outputsize, infini::SmallArray inputXShape,
+             infini::SmallArray inputYShape, infini::SmallArray conditionShape,
+             infini::SmallArray outputShape, int xSize, int ySize, int cSize) {
     int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
     if (outputIdx < outputsize) {
@@ -35,14 +137,74 @@ _whereKernel(const T *inputX, const T *inputY, const uint8_t *condition,
         int inputYIdx =
             inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);

-        output[outputIdx] =
-            condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
+        ((T *)output)[outputIdx] = condition[conditionIdx]
+                                       ? ((T *)inputX)[inputXIdx]
+                                       : ((T *)inputY)[inputYIdx];
     }
 }
+#define CASECurrency(T)                                                        \
+    _whereKernel<DT_CUDA<T>::t>                                                \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(          \
+            inputX, inputY, condition, output, nDims, outputsize, inputXShape, \
+            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
+
+#define SWITCHCurrency_DTYPE(DTYPE)                                            \
+    switch (DTYPE) {                                                           \
+    case 1:                                                                    \
+        CASECurrency(1) break;                                                 \
+    case 2:                                                                    \
+        CASECurrency(2) break;                                                 \
+    case 3:                                                                    \
+        CASECurrency(3) break;                                                 \
+    case 4:                                                                    \
+        CASECurrency(4) break;                                                 \
+    case 5:                                                                    \
+        CASECurrency(5) break;                                                 \
+    case 6:                                                                    \
+        CASECurrency(6) break;                                                 \
+    case 7:                                                                    \
+        CASECurrency(7) break;                                                 \
+    case 10:                                                                   \
+        CASECurrency(10) break;                                                \
+    case 11:                                                                   \
+        CASECurrency(11) break;                                                \
+    case 12:                                                                   \
+        CASECurrency(12) break;                                                \
+    case 13:                                                                   \
+        CASECurrency(13) break;                                                \
+    case 16:                                                                   \
+        CASECurrency(16) break;                                                \
+    default:                                                                   \
+        IT_TODO_HALT();                                                        \
+    }
+
+namespace infini {
+
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int a0, int a1, int a2,
+                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
+                 int c3, int d0, int d1, int d2, int d3) {
+    int blocksize;
+    int outputsize = d0 * d1 * d2 * d3;
+    if (outputsize > 511 * repeat) {
+        blocksize = 1024;
+    } else if (outputsize > 255 * repeat) {
+        blocksize = 512;
+    } else if (outputsize > 127 * repeat) {
+        blocksize = 256;
+    } else if (outputsize > 63 * repeat) {
+        blocksize = 128;
+    } else if (outputsize > 31 * repeat) {
+        blocksize = 64;
+    } else {
+        blocksize = 32;
+    }
+    int gridsize = (outputsize + repeat * blocksize - 1) / (repeat * blocksize);
+
+    SWITCH_DTYPE(dTypeIndex)
+}
+
-namespace infini {
-void whereKernel(const float *inputX, const float *inputY,
-                 const uint8_t *condition, float *output, int nDims,
+void whereKernel(int dTypeIndex, void *inputX, void *inputY,
+                 const uint8_t *condition, void *output, int nDims,
                  int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                  SmallArray conditionShape, SmallArray outputShape, int xSize,
                  int ySize, int cSize) {
@@ -61,34 +223,8 @@ void whereKernel(const float *inputX, const float *inputY,
         blocksize = 32;
     }
     int gridsize = (outputsize + blocksize - 1) / blocksize;
-    _whereKernel<float>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
-            inputX, inputY, condition, output, nDims, outputsize, inputXShape,
-            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
-}
-void whereKernel(const half *inputX, const half *inputY,
-                 const uint8_t *condition, half *output, int nDims,
-                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
-                 SmallArray conditionShape, SmallArray outputShape, int xSize,
-                 int ySize, int cSize) {
-    int blocksize;
-    if (outputsize > 511) {
-        blocksize = 1024;
-    } else if (outputsize > 255) {
-        blocksize = 512;
-    } else if (outputsize > 127) {
-        blocksize = 256;
-    } else if (outputsize > 63) {
-        blocksize = 128;
-    } else if (outputsize > 31) {
-        blocksize = 64;
-    } else {
-        blocksize = 32;
-    }
-    int gridsize = (outputsize + blocksize - 1) / blocksize;
-    _whereKernel<half>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
-            inputX, inputY, condition, output, nDims, outputsize, inputXShape,
-            inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
+
+    SWITCHCurrency_DTYPE(dTypeIndex)
 }

 } // namespace infini
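The dtype-dispatching whereKernel keeps the existing block-size ladder, now scaled by `repeat`. For repeat == 1 it amounts to choosing the smallest power of two strictly greater than outputsize, clamped to [32, 1024] — a standalone restatement:

#include <cassert>

int pickBlocksize(int outputsize) {
    int b = 32;
    while (b <= outputsize && b < 1024)
        b *= 2;
    return b;
}

int main() {
    assert(pickBlocksize(12) == 32);
    assert(pickBlocksize(500) == 512);
    assert(pickBlocksize(512) == 1024);
    assert(pickBlocksize(1 << 20) == 1024); // capped at 1024
}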
Where operator tests:

@@ -84,6 +84,17 @@ void test_whereFp16(
 }

 TEST(CUDA_WhereFp32, run) {
+    test_whereFp32(
+        Shape{2, 2, 3, 1, 2},
+        vector<float>{0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,
+                      8.,  9.,  10., 11., 12., 13., 14., 15.,
+                      16., 17., 18., 19., 20., 21., 22., 23.},
+        Shape{2, 2, 3, 1, 2},
+        vector<float>{0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.},
+        Shape{2, 3, 1, 2}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
+        vector<float>{0., 1.,  2.,  0., 0., 0., 6.,  7.,  0., 9.,  10., 11.,
+                      0., 13., 14., 0., 0., 0., 18., 19., 0., 21., 22., 23.});
     test_whereFp32(
         Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
         Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
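The added Fp32 case is the first to use a rank-5 output, which drives WhereCuda down its SmallArray path, with the rank-4 condition broadcast along the new leading axis. The expected values are easy to confirm on the CPU (standalone check, not part of the test file):

#include <cassert>

int main() {
    // condition of shape {2,3,1,2} (12 values) broadcast over an output of
    // shape {2,2,3,1,2} (24 values) repeats along the new leading axis.
    const unsigned char cond[12] = {0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1};
    const float expected[24] = {0, 1,  2,  0, 0, 0, 6,  7,  0, 9,  10, 11,
                                0, 13, 14, 0, 0, 0, 18, 19, 0, 21, 22, 23};
    for (int i = 0; i < 24; ++i) {
        float x = float(i), y = 0.f; // the test's inputX / inputY
        assert((cond[i % 12] ? x : y) == expected[i]);
    }
}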