add split_concat fp16

xgqdut2016 2023-12-11 16:45:16 +08:00
parent fda0a5f982
commit dd4a90fb5e
5 changed files with 129 additions and 33 deletions

View File

@@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8;
// Concat operator acts like element tensors composing to one big tensor,and
// split operator acts like one big tensor being composed by element
// tensors.
struct ElementTensorMetadata {
float *data[BATCH_SIZE];
template <typename T> struct ElementTensorMetadata {
T *data[BATCH_SIZE];
int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in
// the composed tensor.
int dimSize[BATCH_SIZE]; // the dimention size of the element tensor.
@@ -20,16 +20,17 @@ struct ElementTensorMetadata {
data[i], dimBgNo[i], dimSize[i], nElements[i]);
}
};
struct ComposedTensorMetadata {
template <typename T> struct ComposedTensorMetadata {
int dimSize[DIM_MAX_SIZE];
int stride[DIM_MAX_SIZE];
float *data;
T *data;
};
namespace infini {
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
const ComposedTensorMetadata<float> &compMeta, int dim,
int batchSize, int nDims, bool isSplit);
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
const ComposedTensorMetadata<half> &compMeta, int dim,
int batchSize, int nDims, bool isSplit);
} // namespace infini
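
A minimal sketch of how the templated metadata and the new half overload fit together at a call site; the function below and its setup are illustrative assumptions, not code from this commit.

// Sketch only: fill the templated metadata for an fp16 concat of two 1-D
// device buffers and dispatch to the half overload declared above.
// in0, in1 and out are assumed to be device pointers of the given lengths.
#include <cuda_fp16.h>
#include "cuda/cuda_split_concat.h"

void concatFp16Sketch(half *out, half *in0, half *in1, int n0, int n1) {
    ElementTensorMetadata<half> elem{};
    elem.data[0] = in0; elem.dimBgNo[0] = 0;  elem.dimSize[0] = n0; elem.nElements[0] = n0;
    elem.data[1] = in1; elem.dimBgNo[1] = n0; elem.dimSize[1] = n1; elem.nElements[1] = n1;

    ComposedTensorMetadata<half> comp{};
    comp.dimSize[0] = n0 + n1; // concatenated length along dim 0
    comp.stride[0] = 1;
    comp.data = out;

    // isSplit = false -> copy from the element tensors into the composed one
    infini::split_concat_kernel(elem, comp, /*dim=*/0, /*batchSize=*/2,
                                /*nDims=*/1, /*isSplit=*/false);
}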

View File

@@ -7,7 +7,8 @@
namespace infini {
class CudaCompute {
void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
template <typename T>
void initComposedTensorMetadata(ComposedTensorMetadata<T> &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
@@ -16,10 +17,10 @@ class CudaCompute {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
}
metadata.data = tensor->getRawDataPtr<float *>();
metadata.data = tensor->getRawDataPtr<T *>();
}
void initElementTensorMetadata(ElementTensorMetadata &metadata,
template <typename T>
void initElementTensorMetadata(ElementTensorMetadata<T> &metadata,
TensorVec tensors, int idx, int dim,
int &dimBgIdx, int &batchCounter) const {
int nTensors = tensors.size();
@@ -27,7 +28,7 @@ class CudaCompute {
++batchCounter) {
auto tensor = tensors.at(idx + batchCounter);
auto dimSize = tensor->getDims()[dim];
metadata.data[batchCounter] = tensor->getRawDataPtr<float *>();
metadata.data[batchCounter] = tensor->getRawDataPtr<T *>();
metadata.dimBgNo[batchCounter] = dimBgIdx;
metadata.dimSize[batchCounter] = dimSize;
metadata.nElements[batchCounter] = tensor->size();
@@ -36,17 +37,17 @@ class CudaCompute {
}
public:
template <typename T>
void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
int nDims, bool isSplit) const {
IT_ASSERT(nDims <= DIM_MAX_SIZE);
ComposedTensorMetadata composedMetadata;
initComposedTensorMetadata(composedMetadata, composedTensor);
ComposedTensorMetadata<T> composedMetadata;
initComposedTensorMetadata<T>(composedMetadata, composedTensor);
int dimBgNo = 0;
int nElemets = elementsTensor.size();
for (int i = 0; i < nElemets; i += BATCH_SIZE) {
ElementTensorMetadata elemMetadata;
ElementTensorMetadata<T> elemMetadata;
int batchCounter = 0;
initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
dimBgNo, batchCounter);
@@ -74,18 +75,30 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
}
}
}
do_compute(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
false);
if (_op->getDType() == DataType::Float32) {
do_compute<float>(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(),
_op->getOutput()->getRank(), false);
} else if (_op->getDType() == DataType::Float16) {
do_compute<half>(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(),
_op->getOutput()->getRank(), false);
}
}
};
class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(), _op->getInputs(0)->getRank(),
true);
if (_op->getDType() == DataType::Float32) {
do_compute<float>(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(),
_op->getInputs(0)->getRank(), true);
} else if (_op->getDType() == DataType::Float16) {
do_compute<half>(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(),
_op->getInputs(0)->getRank(), true);
}
}
};
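
Both ConcatCuda and SplitCuda now repeat the same Float32/Float16 branching, which amounts to mapping the operator's runtime DataType onto a compile-time element type. Below is a condensed sketch of that mapping; the tag types and helper are illustrative, not the project's API, and any other dtype falls through silently, exactly as in the branches above.

// Hypothetical tag-dispatch helper, assuming the project's DataType enum and
// CUDA's half type are already in scope. Not part of this commit.
struct Fp32Tag { using type = float; };
struct Fp16Tag { using type = half; };

template <typename Fn> void withDTypeTag(DataType dtype, Fn &&fn) {
    if (dtype == DataType::Float32)
        fn(Fp32Tag{});
    else if (dtype == DataType::Float16)
        fn(Fp16Tag{});
    // other dtypes are skipped without an error, mirroring the if/else above
}

// Used inside compute(), this would collapse the duplicated branches:
//   withDTypeTag(_op->getDType(), [&](auto tag) {
//       using T = typename decltype(tag)::type;
//       do_compute<T>(_op->getOutput(), _op->getInputs(),
//                     as<ConcatObj>(_op)->getDim(),
//                     _op->getOutput()->getRank(), false);
//   });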

View File

@@ -1,9 +1,9 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_split_concat.h"
template <typename T>
__host__ __device__ int
elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
int nDim, ComposedTensorMetadata wholeMeta) {
int nDim, ComposedTensorMetadata<T> wholeMeta) {
int offset = 0;
// COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1)
@@ -25,10 +25,10 @@ elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
return offset + oP * wholeMeta.stride[0];
}
__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
ComposedTensorMetadata compMeta, int dim,
int nDims, bool isSplit) {
template <typename T>
__global__ void _split_concat_kernel(ElementTensorMetadata<T> elemMeta,
ComposedTensorMetadata<T> compMeta,
int dim, int nDims, bool isSplit) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int nElements = elemMeta.nElements[blockIdx.y];
if (tid >= nElements)
@@ -36,10 +36,10 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
auto dimSize = elemMeta.dimSize[blockIdx.y];
float *elemData = elemMeta.data[blockIdx.y];
T *elemData = elemMeta.data[blockIdx.y];
int Offset =
elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
elementIdx2ComposedIdx<T>(tid, dimBgNo, dimSize, dim, nDims, compMeta);
// copy data from input to output
// for split:input is composed tensor;for concat:input is element
// tensors.
@@ -52,8 +52,22 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
namespace infini {
// TODO: when dim=0, the operation can be executed in-place
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
const ComposedTensorMetadata<float> &compMeta, int dim,
int batchSize, int nDims, bool isSplit) {
dim3 blockSize = dim3(32 * 16);
// gridsize = max_n_elements / blockSize
int max_n_elements =
*std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
// each y is a split among the batch
dim3 gridSize(gridDimX, batchSize);
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
isSplit);
}
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
const ComposedTensorMetadata<half> &compMeta, int dim,
int batchSize, int nDims, bool isSplit) {
dim3 blockSize = dim3(32 * 16);
// gridsize = max_n_elements / blockSize
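
The hunk above is cut off inside the half overload; since it mirrors the float overload shown in full, the remainder presumably continues the same way. The following is a sketch of that presumed continuation, inferred from the float version rather than taken from the diff.

    // presumed remainder of the half overload, mirroring the float version
    int max_n_elements =
        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
    // each y is a split among the batch
    dim3 gridSize(gridDimX, batchSize);
    // template argument deduction instantiates _split_concat_kernel<half> here
    _split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
                                                  isSplit);
}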

View File

@@ -187,4 +187,42 @@ TEST(ConcatToIdentity, Cuda) {
EXPECT_TRUE(
oCpu->equalData(vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
}
//----------
TEST(ConcatFp16, CudaHigh) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto t1 = gCpu->addTensor({2, 2, 3, 1, 2}, DataType::Float16);
auto t2 = gCpu->addTensor({2, 2, 1, 1, 2}, DataType::Float16);
auto t3 = gCpu->addTensor({2, 2, 2, 1, 2}, DataType::Float16);
gCpu->dataMalloc();
t1->setData(ValGenerator<2>());
t2->setData(ValGenerator<1>());
t3->setData(ValGenerator<4>());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto t1Gpu = gCuda->cloneTensor(t1);
auto t2Gpu = gCuda->cloneTensor(t2);
auto t3Gpu = gCuda->cloneTensor(t3);
auto op =
gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 2);
gCuda->dataMalloc();
t1Gpu->setData(ValGenerator<2>());
t2Gpu->setData(ValGenerator<1>());
t3Gpu->setData(ValGenerator<4>());
cudaRuntime->run(gCuda);
// cudaPrintTensor(op->getOutput());
// copy output from CUDA to CPU
auto oCpu = gCpu->cloneTensor(op->getOutput());
EXPECT_TRUE(oCpu->equalData(vector<float>{
2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2.,
2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2., 2., 2., 1., 1.,
4., 4., 4., 4., 2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4.}));
}
} // namespace infini

View File

@@ -98,5 +98,35 @@ TEST(Split, Cuda_dim0) {
EXPECT_TRUE(o0Cpu->equalData(vector<float>{0, 1, 2}));
EXPECT_TRUE(o1Cpu->equalData(vector<float>{3, 4, 5}));
}
//----------------
TEST(SplitFp16, CudaHigh) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float16);
gCpu->dataMalloc();
input->setData(ValGenerator<2>());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, 3);
gCuda->dataMalloc();
inputGpu->setData(ValGenerator<2>());
cudaRuntime->run(gCuda);
// copy output from CUDA to CPU
EXPECT_EQ(op->getOutputs().size(), (size_t)3);
auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
auto o2Cpu = gCpu->cloneTensor(op->getOutput(2));
EXPECT_TRUE(o0Cpu->equalData(vector<float>{
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
EXPECT_TRUE(o1Cpu->equalData(vector<float>{
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
EXPECT_TRUE(o2Cpu->equalData(vector<float>{
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
}
} // namespace infini