add currency operator

This commit is contained in:
xgqdut2016 2024-04-10 15:01:22 +08:00
parent 86133c8d0a
commit 73e3f1fc6f
8 changed files with 280 additions and 38 deletions

View File

@ -5,7 +5,9 @@
namespace infini {
// Broadcast `input` into `output` for ranks <= 4; a0..a3 / b0..b3 are the
// input / output dims, right-aligned and padded with 1 by the caller.
void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3);
// General-rank overload: shapes passed as SmallArray of length nDims
// (nDims <= SMALL_ARRAY_SIZE); `outputsize` is the flattened output
// element count.
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape);
// NOTE(review): presumably replicates one row of `row_len` elements across
// `n_rows` rows -- kernel body not visible here, confirm against expand.cu.
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len);
}; // namespace infini

View File

@ -1,10 +1,15 @@
#pragma once
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
// Elementwise select with rank-4 broadcasting:
//   output = condition ? inputX : inputY.
// a*/b*/c*/d* are the dims of inputX / inputY / condition / output,
// right-aligned and padded with 1 by the caller.
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3, int d0, int d1, int d2, int d3);
// General-rank overload: shapes right-aligned into SmallArrays of length
// nDims; xSize/ySize/cSize are the original ranks of inputX / inputY /
// condition (only that many trailing dims of each shape are significant).
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
const uint8_t *condition, void *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
}; // namespace infini

View File

@ -16,15 +16,29 @@ class ExpandCuda : public CudaKernelWithoutConfig {
auto b_dim = op->getOutput()->getDims(); // output shape
const int dType = op->getDType().getIndex();
if (a_dim.size() > 4 || b_dim.size() > 4)
IT_TODO_HALT();
if (a_dim.size() > 4 || b_dim.size() > 4) {
SmallArray inputShape, outputShape;
int nDims = op->getInputs(0)->getDims().size();
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
expandKernel(dType, inputData, outputData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3]);
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1; // the length of the output vector after flatten
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = b_dim[i];
inputShape.data[i] = a_dim[i];
outputsize *= b_dim[i];
}
const int dType = op->getDType().getIndex();
expandKernel(dType, inputData, outputData, nDims, outputsize,
inputShape, outputShape);
} else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
expandKernel(dType, inputData, outputData, a[0], a[1], a[2], a[3],
b[0], b[1], b[2], b[3]);
}
}
};

View File

@ -31,7 +31,37 @@ __global__ void _expandKernel(void *input, void *output, int a0, int a1, int a2,
((T *)output)[i] = ((T *)input)[xIdx];
}
}
// General-rank expand (broadcast) kernel: one thread per output element.
// The flat output index is decoded dim by dim (last dim fastest) and
// re-folded into a flat input index; input dims of extent 1 are broadcast
// (their coordinate is forced to 0).
// NOTE(review): assumes inputShape and outputShape both have rank nDims --
// confirm callers pad the input shape when ranks differ.
template <class T>
__global__ void _expandKernel(void *input, void *output, int nDims,
int outputsize, infini::SmallArray inputShape,
infini::SmallArray outputShape) {
int outputIdx =
blockIdx.x * blockDim.x + threadIdx.x; // flat index into the output
if (outputIdx < outputsize) {
int inputIdx = 0; // flat input index accumulated below
int temp = 1; // running input stride for the current dim
int tmp = 1; // output coordinate along the current dim
int v = outputIdx; // still-undecoded part of outputIdx
for (int i = nDims - 1; i >= 0; --i) {
if (i == 0) {
tmp = v; // outermost dim: whatever remains is the coordinate
} else {
tmp = v % outputShape.data[i]; // coordinate along dim i
}
if (inputShape.data[i] ==
1) { // broadcast dim: input coordinate is always 0
inputIdx += 0;
} else {
inputIdx +=
tmp * temp; // non-broadcast: coordinate * input stride
}
temp *= inputShape.data[i];
v = v / outputShape.data[i];
}
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
}
}
template <class T>
static __global__ void _expandRowKernel(void *__restrict__ dst,
void const *__restrict__ src) {
@ -97,7 +127,48 @@ void expandKernel(int dType, void *input, void *output, int a0, int a1, int a2,
(repeat * block_work_size());
SWITCH_DTYPE(dType)
}
// Instantiate and launch the general-rank expand kernel for runtime dtype
// index T (DT_CUDA maps the index to the device element type). Expects
// `gridsize`, `blocksize`, and the kernel arguments to be in scope at the
// expansion site.
#define CASECurrency(T) \
_expandKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, outputsize, inputShape, outputShape);
// Dispatch on the runtime dtype index; the case values mirror the dtype
// enum consumed by DT_CUDA. Unsupported indices halt.
#define SWITCHCurrency_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASECurrency(1) break; \
case 2: \
CASECurrency(2) break; \
case 3: \
CASECurrency(3) break; \
case 4: \
CASECurrency(4) break; \
case 5: \
CASECurrency(5) break; \
case 6: \
CASECurrency(6) break; \
case 7: \
CASECurrency(7) break; \
case 10: \
CASECurrency(10) break; \
case 11: \
CASECurrency(11) break; \
case 12: \
CASECurrency(12) break; \
case 13: \
CASECurrency(13) break; \
case 16: \
CASECurrency(16) break; \
default: \
IT_TODO_HALT(); \
}
// Host launcher for the general-rank expand kernel: one thread per output
// element, grid rounded up to whole blocks, dispatched on the dtype index.
void expandKernel(int dType, void *input, void *output, int nDims,
                  int outputsize, SmallArray inputShape,
                  SmallArray outputShape) {
    const int blocksize = block_work_size();
    // Ceil-divide so every output element is covered by a thread.
    const int gridsize = (outputsize + blocksize - 1) / blocksize;
    SWITCHCurrency_DTYPE(dType)
}
#define CASE_ROW(T) \
_expandRowKernel<float> \
<<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);

View File

@ -101,16 +101,37 @@ class matmulCublas : public Kernel {
auto a_dim = out->getDims();
auto b_dim = inC->getDims(); // output shape
if (a_dim.size() > 4 || b_dim.size() > 4)
IT_TODO_HALT();
if (a_dim.size() > 4 || b_dim.size() > 4) {
SmallArray inputShape, outputShape;
int nDims = out->getRank();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
// FIXME(constroy): use size_t for outputsize.
int outputsize =
1; // the length of the output vector after flatten
int offset = nDims - inC->getRank();
for (int i = 0; i < offset; ++i)
inputShape.data[i] = 1;
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = out->getDims()[i];
outputsize *= outputShape.data[i];
if (i >= offset)
inputShape.data[i] = inC->getDims()[i - offset];
}
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims,
outputsize, inputShape, outputShape);
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3]);
} else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(),
a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(),
b + (4 - b_dim.size()));
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3]);
}
}
}
// TODO:use compute type

View File

@ -20,26 +20,46 @@ class WhereCuda : public CudaKernelWithoutConfig {
auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getInputs(2)->getDims();
auto d_dim = op->getOutput()->getDims();
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
d_dim.size() > 4)
IT_TODO_HALT();
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
int c[4] = {1, 1, 1, 1};
int d[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
const int dTypeIndex = op->getDType().getIndex();
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3], d[0],
d[1], d[2], d[3]);
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4 ||
d_dim.size() > 4) {
const int xSize = op->getInputs(0)->getRank();
const int ySize = op->getInputs(1)->getRank();
const int cSize = op->getInputs(2)->getRank();
int nDims = op->getOutput()->getDims().size();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1;
SmallArray inputXShape, inputYShape, conditionShape, outputShape;
for (int i = nDims - 1; i >= 0; --i) {
outputShape.data[i] = d_dim[i];
outputsize *= outputShape.data[i];
}
broadcastShape(a_dim, inputXShape, nDims, xSize);
broadcastShape(b_dim, inputYShape, nDims, ySize);
broadcastShape(c_dim, conditionShape, nDims, cSize);
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, nDims, outputsize,
inputXShape, inputYShape, conditionShape, outputShape,
xSize, ySize, cSize);
}
else {
int a[4] = {1, 1, 1, 1};
int b[4] = {1, 1, 1, 1};
int c[4] = {1, 1, 1, 1};
int d[4] = {1, 1, 1, 1};
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
std::copy(d_dim.begin(), d_dim.end(), d + (4 - d_dim.size()));
whereKernel(dTypeIndex, inputXData, inputYData,
(uint8_t *)conditionData, outputData, a[0], a[1], a[2],
a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3],
d[0], d[1], d[2], d[3]);
}
}
};

View File

@ -1,5 +1,6 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
const int repeat = 1;
template <typename T>
@ -103,6 +104,79 @@ _whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
default: \
IT_TODO_HALT(); \
}
// Map a flat output index to the flat index of a broadcast operand whose
// shape is right-aligned in `inputShape` with `size` significant trailing
// dims. Dims of extent 1 are broadcast: their coordinate is forced to 0.
__device__ int inferIndex(infini::SmallArray inputShape,
                          infini::SmallArray outputShape, int nDims, int size,
                          int outputIdx) {
    int inputIdx = 0; // accumulated flat input index
    int stride = 1;   // running input stride for the current dim
    for (int i = nDims - 1; i >= nDims - size; --i) {
        const int coord = outputIdx % outputShape.data[i];
        if (inputShape.data[i] != 1) {
            inputIdx += coord * stride; // non-broadcast: coord * stride
        }
        stride *= inputShape.data[i];
        outputIdx /= outputShape.data[i];
    }
    return inputIdx;
}
// General-rank where kernel: one thread per output element selects from
// inputX or inputY according to `condition`, with each operand's flat index
// resolved through its (possibly broadcast) shape via inferIndex.
template <typename T>
__global__ void
_whereKernel(void *inputX, void *inputY, const uint8_t *condition, void *output,
             int nDims, int outputsize, infini::SmallArray inputXShape,
             infini::SmallArray inputYShape, infini::SmallArray conditionShape,
             infini::SmallArray outputShape, int xSize, int ySize, int cSize) {
    const int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
    if (outputIdx >= outputsize)
        return; // grid tail guard; no barriers below, so early exit is safe
    const int conditionIdx =
        inferIndex(conditionShape, outputShape, nDims, cSize, outputIdx);
    const int inputXIdx =
        inferIndex(inputXShape, outputShape, nDims, xSize, outputIdx);
    const int inputYIdx =
        inferIndex(inputYShape, outputShape, nDims, ySize, outputIdx);
    ((T *)output)[outputIdx] = condition[conditionIdx]
                                   ? ((T *)inputX)[inputXIdx]
                                   : ((T *)inputY)[inputYIdx];
}
// Instantiate and launch the general-rank where kernel for runtime dtype
// index T (DT_CUDA maps the index to the device element type). Expects
// `gridsize`, `blocksize`, and the kernel arguments to be in scope at the
// expansion site.
#define CASECurrency(T) \
_whereKernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
inputX, inputY, condition, output, nDims, outputsize, inputXShape, \
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
// Dispatch on the runtime dtype index; the case values mirror the dtype
// enum consumed by DT_CUDA. Unsupported indices halt.
#define SWITCHCurrency_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASECurrency(1) break; \
case 2: \
CASECurrency(2) break; \
case 3: \
CASECurrency(3) break; \
case 4: \
CASECurrency(4) break; \
case 5: \
CASECurrency(5) break; \
case 6: \
CASECurrency(6) break; \
case 7: \
CASECurrency(7) break; \
case 10: \
CASECurrency(10) break; \
case 11: \
CASECurrency(11) break; \
case 12: \
CASECurrency(12) break; \
case 13: \
CASECurrency(13) break; \
case 16: \
CASECurrency(16) break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
@ -129,4 +203,28 @@ void whereKernel(int dTypeIndex, void *inputX, void *inputY,
SWITCH_DTYPE(dTypeIndex)
}
// Host launcher for the general-rank where kernel: picks a block size,
// rounds the grid up to cover every output element, and dispatches on the
// runtime dtype index.
void whereKernel(int dTypeIndex, void *inputX, void *inputY,
                 const uint8_t *condition, void *output, int nDims,
                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                 SmallArray conditionShape, SmallArray outputShape, int xSize,
                 int ySize, int cSize) {
    // Smallest power of two strictly greater than outputsize, clamped to
    // [32, 1024] -- same mapping as the original threshold chain
    // (>511 -> 1024, >255 -> 512, ..., else 32).
    int blocksize = 32;
    while (blocksize < 1024 && outputsize > blocksize - 1) {
        blocksize <<= 1;
    }
    int gridsize = (outputsize + blocksize - 1) / blocksize;
    SWITCHCurrency_DTYPE(dTypeIndex)
}
} // namespace infini

View File

@ -84,6 +84,17 @@ void test_whereFp16(
}
TEST(CUDA_WhereFp32, run) {
test_whereFp32(
Shape{2, 2, 3, 1, 2},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7.,
8., 9., 10., 11., 12., 13., 14., 15.,
16., 17., 18., 19., 20., 21., 22., 23.},
Shape{2, 2, 3, 1, 2},
vector<float>{0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.},
Shape{2, 3, 1, 2}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.,
0., 13., 14., 0., 0., 0., 18., 19., 0., 21., 22., 23.});
test_whereFp32(
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},