Fix incorrect split/concat results when dim = 0 (#138)

Fix split_concat kernel not supporting dim=0

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Author: PanZezhong1725 (committed by GitHub)
Date:   2023-09-25 10:25:54 +08:00
parent 8f2597a508
commit 62be816f53
3 changed files with 95 additions and 28 deletions


@@ -1,30 +1,29 @@
 #include "cuda/cuda_common.h"
 #include "cuda/cuda_split_concat.h"
 
-int getMultiProcessorCount() {
-    int cur_device;
-    checkCudaError(cudaGetDevice(&cur_device));
-
-    struct cudaDeviceProp prop;
-    checkCudaError(cudaGetDeviceProperties(&prop, cur_device));
-    return prop.multiProcessorCount;
-}
-
 __host__ __device__ int
 elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
                        int nDim, ComposedTensorMetadata wholeMeta) {
     int offset = 0;
 
+    // COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...,xn-1)
+    // where k = dim, n = nDim, and d = dimSize is the split length of
+    // dimension dim
 #pragma unroll
+    // Iterate through dimensions n-1 down to 1
     for (int i = nDim - 1; i >= 1; --i) {
         int size = (i == dim) ? dimSize : wholeMeta.dimSize[i];
         int p = elementIndex % size;
+        // dimBgNo moves the pointer to the correct location in the composed
+        // data corresponding to the current element, with respect to the
+        // split dimension dim
         int oP = (i == dim) ? (p + dimBgNo) : p;
         elementIndex = (elementIndex - p) / size;
         offset += oP * wholeMeta.stride[i];
     }
-
-    return offset + elementIndex * wholeMeta.stride[0];
+    // Deal with i = 0
+    int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
+    return offset + oP * wholeMeta.stride[0];
 }
 
 __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
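The bug in a nutshell: the loop above folds dimBgNo into dimensions n-1 down to 1 only, so when the split/concat dimension is 0 the old return statement dropped the per-piece offset and every element tensor mapped onto the start of the composed tensor. Below is a minimal host-side sketch, not part of the patch: the helper name pieceIdxToComposedIdx, the vector-based shapes/strides, and the {2, 3} example are assumptions for illustration of the fixed mapping.

#include <cassert>
#include <vector>

// Maps a linear index inside one split piece to the linear offset in the
// composed (concatenated) tensor, mirroring elementIdx2ComposedIdx above.
int pieceIdxToComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
                          const std::vector<int> &compDims,
                          const std::vector<int> &compStrides) {
    int nDim = (int)compDims.size();
    int offset = 0;
    for (int i = nDim - 1; i >= 1; --i) {
        int size = (i == dim) ? dimSize : compDims[i];
        int p = elementIndex % size;
        int oP = (i == dim) ? (p + dimBgNo) : p;
        elementIndex = (elementIndex - p) / size;
        offset += oP * compStrides[i];
    }
    // The fix: dimension 0 must also be shifted by dimBgNo when dim == 0.
    int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
    return offset + oP * compStrides[0];
}

int main() {
    // Composed tensor of shape {2, 3}: two {1, 3} pieces concatenated on dim 0.
    std::vector<int> dims = {2, 3}, strides = {3, 1};
    // Element 0 of the second piece (dimBgNo = 1) lives at offset 3, not 0.
    assert(pieceIdxToComposedIdx(0, 1, 1, 0, dims, strides) == 3);
    assert(pieceIdxToComposedIdx(2, 1, 1, 0, dims, strides) == 5);
    // Elements of the first piece (dimBgNo = 0) keep their offsets.
    assert(pieceIdxToComposedIdx(1, 0, 1, 0, dims, strides) == 1);
    return 0;
}

With the old return statement (offset + elementIndex * stride[0]) the first two asserts would fail, because dimBgNo never reached dimension 0.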
@@ -38,31 +37,29 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
     auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
     auto dimSize = elemMeta.dimSize[blockIdx.y];
     float *elemData = elemMeta.data[blockIdx.y];
-    int stride = gridDim.x * blockDim.x;
 
-    while (tid < nElements) {
-        int Offset =
-            elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
-        // copy data from input to output
-        // for split:input is composed tensor;for concat:input is element
-        // tensors.
-        if (isSplit)
-            elemData[tid] = compMeta.data[Offset];
-        else
-            compMeta.data[Offset] = elemData[tid];
-        tid += stride;
-    }
+    int Offset =
+        elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
+    // copy data from input to output
+    // for split: the input is the composed tensor; for concat: the inputs are
+    // the element tensors.
+    if (isSplit)
+        elemData[tid] = compMeta.data[Offset];
+    else
+        compMeta.data[Offset] = elemData[tid];
 }
 
 namespace infini {
 
+// TODO: when dim = 0, the operation can be executed in-place
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                          const ComposedTensorMetadata &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
-
-    // y dim is number of tensors.
-    dim3 gridSize(getMultiProcessorCount(), batchSize);
+    // gridSize.x = nElements / blockSize
+    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
+    // each y is a split among the batch
+    dim3 gridSize(gridDimX, batchSize);
 
     _split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
                                                   isSplit);
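The launch configuration is now sized from the element count rather than the SM count: gridDim.x covers ceil(nElements / blockSize) blocks so each thread handles one element of a piece, while gridDim.y selects which piece of the batch a block works on. A tiny sketch of that ceiling-division arithmetic; ceilDiv and the element counts are illustrative assumptions, not project API.

#include <cassert>

// Ceiling division, the same (n - 1) / d + 1 pattern used for gridDimX above.
// Assumes n >= 1.
static int ceilDiv(int n, int d) { return (n - 1) / d + 1; }

int main() {
    const int blockSize = 32 * 16; // 512 threads per block, as in the kernel
    // A piece with 1000 elements needs 2 blocks in x; with 3 pieces in the
    // batch the grid would be 2 x 3 blocks.
    assert(ceilDiv(1000, blockSize) == 2);
    assert(ceilDiv(512, blockSize) == 1); // exact multiples need no extra block
    assert(ceilDiv(513, blockSize) == 2);
    return 0;
}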


@@ -8,6 +8,7 @@
 namespace infini {
 
 /*
+// Test CUDA split idx to composed idx on CPU. Uncomment to run this test.
 int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
                           int concatDim, int outputDimSize[4],
                           int outputStride[4], int nDim) {
@@ -22,7 +23,8 @@ int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
         offset += oP * outputStride[i];
     }
 
-    return offset + linearIndex * outputStride[0];
+    int oP = (concatDim == 0) ? (linearIndex + dimBgNo) : linearIndex;
+    return offset + oP * outputStride[0];
 }
 
 TEST(Concat, OffsetTrans) {
@@ -41,8 +43,22 @@ TEST(Concat, OffsetTrans) {
               4);
     EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim),
               5);
+
+    catDim = 0;
+    EXPECT_EQ(inputOffset2CatOffset(0, 0, 3, catDim, dimSize, strides, nDim),
+              0);
+    EXPECT_EQ(inputOffset2CatOffset(1, 0, 3, catDim, dimSize, strides, nDim),
+              1);
+    EXPECT_EQ(inputOffset2CatOffset(2, 0, 3, catDim, dimSize, strides, nDim),
+              2);
+    EXPECT_EQ(inputOffset2CatOffset(0, 1, 3, catDim, dimSize, strides, nDim),
+              3);
+    EXPECT_EQ(inputOffset2CatOffset(1, 1, 3, catDim, dimSize, strides, nDim),
+              4);
+    EXPECT_EQ(inputOffset2CatOffset(2, 1, 3, catDim, dimSize, strides, nDim),
+              5);
 }
 */
 TEST(Concat, Cuda) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
@@ -78,4 +94,32 @@ TEST(Concat, Cuda) {
                                   6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
 }
 
+TEST(Concat, Cuda_dim0) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto t1 = gCpu->addTensor({1, 3}, DataType::Float32);
+    auto t2 = gCpu->addTensor({1, 3}, DataType::Float32);
+    auto t3 = gCpu->addTensor({1, 3}, DataType::Float32);
+    gCpu->dataMalloc();
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto t1Gpu = gCuda->cloneTensor(t1);
+    auto t2Gpu = gCuda->cloneTensor(t2);
+    auto t3Gpu = gCuda->cloneTensor(t3);
+
+    auto op =
+        gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 0);
+    gCuda->dataMalloc();
+    t1Gpu->setData(IncrementalGenerator()); // 0 1 2
+    t2Gpu->setData(OneGenerator());         // 1 1 1
+    t3Gpu->setData(IncrementalGenerator()); // 0 1 2
+
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    EXPECT_TRUE(oCpu->equalData(vector<float>{0, 1, 2, 1, 1, 1, 0, 1, 2}));
+}
 } // namespace infini


@@ -39,4 +39,30 @@ TEST(Split, Cuda) {
         12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39}));
 }
 
+TEST(Split, Cuda_dim0) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({2, 3}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->setData(IncrementalGenerator());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputGpu = gCuda->cloneTensor(input);
+    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 0, 2);
+    gCuda->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+
+    cudaRuntime->run(gCuda);
+
+    // copy output from CUDA to CPU
+    EXPECT_EQ(op->getOutputs().size(), (size_t)2);
+    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
+    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
+    EXPECT_TRUE(o0Cpu->equalData(vector<float>{0, 1, 2}));
+    EXPECT_TRUE(o1Cpu->equalData(vector<float>{3, 4, 5}));
+}
 } // namespace infini