forked from jiuyuan/InfiniTensor
修复split concat当dim=0结果出错的问题 (#138)
Fix split_concat kernel not supporting dim=0 Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
8f2597a508
commit
62be816f53
|
@ -1,30 +1,29 @@
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
#include "cuda/cuda_split_concat.h"
|
#include "cuda/cuda_split_concat.h"
|
||||||
|
|
||||||
int getMultiProcessorCount() {
|
|
||||||
int cur_device;
|
|
||||||
checkCudaError(cudaGetDevice(&cur_device));
|
|
||||||
|
|
||||||
struct cudaDeviceProp prop;
|
|
||||||
checkCudaError(cudaGetDeviceProperties(&prop, cur_device));
|
|
||||||
return prop.multiProcessorCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ int
|
__host__ __device__ int
|
||||||
elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
|
elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
|
||||||
int nDim, ComposedTensorMetadata wholeMeta) {
|
int nDim, ComposedTensorMetadata wholeMeta) {
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
|
|
||||||
|
// COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1)
|
||||||
|
// where k=dim, n=ndim, d=dimSize is the splited length of
|
||||||
|
// dimension dim
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
// Interate through n-1 to 1
|
||||||
for (int i = nDim - 1; i >= 1; --i) {
|
for (int i = nDim - 1; i >= 1; --i) {
|
||||||
int size = (i == dim) ? dimSize : wholeMeta.dimSize[i];
|
int size = (i == dim) ? dimSize : wholeMeta.dimSize[i];
|
||||||
int p = elementIndex % size;
|
int p = elementIndex % size;
|
||||||
|
// dimBgNo move the pointer to correct location in composed data
|
||||||
|
// corresponding to current element, with repect to the splitted
|
||||||
|
// dimension dim
|
||||||
int oP = (i == dim) ? (p + dimBgNo) : p;
|
int oP = (i == dim) ? (p + dimBgNo) : p;
|
||||||
elementIndex = (elementIndex - p) / size;
|
elementIndex = (elementIndex - p) / size;
|
||||||
offset += oP * wholeMeta.stride[i];
|
offset += oP * wholeMeta.stride[i];
|
||||||
}
|
}
|
||||||
|
// Deal with i = 0
|
||||||
return offset + elementIndex * wholeMeta.stride[0];
|
int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
|
||||||
|
return offset + oP * wholeMeta.stride[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
|
__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
|
||||||
|
@ -38,9 +37,7 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
|
||||||
auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
|
auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
|
||||||
auto dimSize = elemMeta.dimSize[blockIdx.y];
|
auto dimSize = elemMeta.dimSize[blockIdx.y];
|
||||||
float *elemData = elemMeta.data[blockIdx.y];
|
float *elemData = elemMeta.data[blockIdx.y];
|
||||||
int stride = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
while (tid < nElements) {
|
|
||||||
int Offset =
|
int Offset =
|
||||||
elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
|
elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
|
||||||
// copy data from input to output
|
// copy data from input to output
|
||||||
|
@ -50,19 +47,19 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
|
||||||
elemData[tid] = compMeta.data[Offset];
|
elemData[tid] = compMeta.data[Offset];
|
||||||
else
|
else
|
||||||
compMeta.data[Offset] = elemData[tid];
|
compMeta.data[Offset] = elemData[tid];
|
||||||
tid += stride;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
// TODO: when dim=0, the operation can be executed in-place
|
||||||
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
|
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
|
||||||
const ComposedTensorMetadata &compMeta, int dim,
|
const ComposedTensorMetadata &compMeta, int dim,
|
||||||
int batchSize, int nDims, bool isSplit) {
|
int batchSize, int nDims, bool isSplit) {
|
||||||
dim3 blockSize = dim3(32 * 16);
|
dim3 blockSize = dim3(32 * 16);
|
||||||
|
// gridsize =n_elements / blockSize
|
||||||
// y dim is number of tensors.
|
int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
|
||||||
dim3 gridSize(getMultiProcessorCount(), batchSize);
|
// each y is a split among the batch
|
||||||
|
dim3 gridSize(gridDimX, batchSize);
|
||||||
|
|
||||||
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
|
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
|
||||||
isSplit);
|
isSplit);
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
/*
|
/*
|
||||||
|
// Test cuda splitted idx to complosed idx in cpu. Uncomment to run this test.
|
||||||
int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
|
int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
|
||||||
int concatDim, int outputDimSize[4],
|
int concatDim, int outputDimSize[4],
|
||||||
int outputStride[4], int nDim) {
|
int outputStride[4], int nDim) {
|
||||||
|
@ -22,7 +23,8 @@ int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
|
||||||
offset += oP * outputStride[i];
|
offset += oP * outputStride[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
return offset + linearIndex * outputStride[0];
|
int oP = (concatDim == 0) ? (linearIndex + dimBgNo) : linearIndex;
|
||||||
|
return offset + oP * outputStride[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Concat, OffsetTrans) {
|
TEST(Concat, OffsetTrans) {
|
||||||
|
@ -41,8 +43,22 @@ TEST(Concat, OffsetTrans) {
|
||||||
4);
|
4);
|
||||||
EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim),
|
EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim),
|
||||||
5);
|
5);
|
||||||
|
catDim = 0;
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(0, 0, 3, catDim, dimSize, strides, nDim),
|
||||||
|
0);
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(1, 0, 3, catDim, dimSize, strides, nDim),
|
||||||
|
1);
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(2, 0, 3, catDim, dimSize, strides, nDim),
|
||||||
|
2);
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(0, 1, 3, catDim, dimSize, strides, nDim),
|
||||||
|
3);
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(1, 1, 3, catDim, dimSize, strides, nDim),
|
||||||
|
4);
|
||||||
|
EXPECT_EQ(inputOffset2CatOffset(2, 1, 3, catDim, dimSize, strides, nDim),
|
||||||
|
5);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
TEST(Concat, Cuda) {
|
TEST(Concat, Cuda) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
@ -78,4 +94,32 @@ TEST(Concat, Cuda) {
|
||||||
6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
|
6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Concat, Cuda_dim0) {
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
|
||||||
|
auto t1 = gCpu->addTensor({1, 3}, DataType::Float32);
|
||||||
|
auto t2 = gCpu->addTensor({1, 3}, DataType::Float32);
|
||||||
|
auto t3 = gCpu->addTensor({1, 3}, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto t1Gpu = gCuda->cloneTensor(t1);
|
||||||
|
auto t2Gpu = gCuda->cloneTensor(t2);
|
||||||
|
auto t3Gpu = gCuda->cloneTensor(t3);
|
||||||
|
|
||||||
|
auto op =
|
||||||
|
gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 0);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
t1Gpu->setData(IncrementalGenerator()); // 0 1 2
|
||||||
|
t2Gpu->setData(OneGenerator()); // 1 1 1
|
||||||
|
t3Gpu->setData(IncrementalGenerator()); // 0 1 2
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||||
|
EXPECT_TRUE(oCpu->equalData(vector<float>{0, 1, 2, 1, 1, 1, 0, 1, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -39,4 +39,30 @@ TEST(Split, Cuda) {
|
||||||
12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39}));
|
12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Split, Cuda_dim0) {
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
|
||||||
|
auto input = gCpu->addTensor({2, 3}, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
input->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto inputGpu = gCuda->cloneTensor(input);
|
||||||
|
auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 0, 2);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
inputGpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
// copy output from CUDA to CPU
|
||||||
|
EXPECT_EQ(op->getOutputs().size(), (size_t)2);
|
||||||
|
auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
|
||||||
|
auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
|
||||||
|
EXPECT_TRUE(o0Cpu->equalData(vector<float>{0, 1, 2}));
|
||||||
|
EXPECT_TRUE(o1Cpu->equalData(vector<float>{3, 4, 5}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue