forked from jiuyuan/InfiniTensor
add split_concat fp16
This commit is contained in: parent fda0a5f982 · commit dd4a90fb5e
@@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8;
 // Concat operator acts like element tensors composing into one big tensor, and
 // split operator acts like one big tensor being decomposed into element
 // tensors.
-struct ElementTensorMetadata {
-    float *data[BATCH_SIZE];
+template <typename T> struct ElementTensorMetadata {
+    T *data[BATCH_SIZE];
     int dimBgNo[BATCH_SIZE]; // the dimension begin no. of the element tensor
                              // in the composed tensor.
     int dimSize[BATCH_SIZE]; // the dimension size of the element tensor.
@@ -20,16 +20,17 @@ struct ElementTensorMetadata {
                data[i], dimBgNo[i], dimSize[i], nElements[i]);
     }
 };

-struct ComposedTensorMetadata {
+template <typename T> struct ComposedTensorMetadata {
     int dimSize[DIM_MAX_SIZE];
     int stride[DIM_MAX_SIZE];
-    float *data;
+    T *data;
 };

 namespace infini {
-void split_concat_kernel(const ElementTensorMetadata &eleMeta,
-                         const ComposedTensorMetadata &compMeta, int dim,
+void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
+                         const ComposedTensorMetadata<float> &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit);
+void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
+                         const ComposedTensorMetadata<half> &compMeta, int dim,
+                         int batchSize, int nDims, bool isSplit);

 } // namespace infini
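Note: the header now templates the metadata structs but pins the kernel entry point to explicit float and half overloads; since the kernel body is compiled from a separate .cu file, a single function template could not be instantiated from the host-side callers, so one overload per supported dtype is declared instead. Below is a minimal compilable sketch of the resulting struct layout; the value of BATCH_SIZE and the nElements field are assumptions, since this diff does not show them.

// Sketch only: mirrors the templated structs above. Requires the CUDA
// toolkit for cuda_fp16.h; compile with nvcc or with g++ plus CUDA includes.
#include <cuda_fp16.h> // brings in the half type

const int BATCH_SIZE = 32;  // assumption for this sketch; the real value
                            // is defined earlier in the header
const int DIM_MAX_SIZE = 8; // as shown in the hunk header above

template <typename T> struct ElementTensorMetadata {
    T *data[BATCH_SIZE];
    int dimBgNo[BATCH_SIZE];   // begin index along the concat/split dim
    int dimSize[BATCH_SIZE];   // extent along that dim
    int nElements[BATCH_SIZE]; // total scalars in each element tensor
};

template <typename T> struct ComposedTensorMetadata {
    int dimSize[DIM_MAX_SIZE];
    int stride[DIM_MAX_SIZE];
    T *data;
};

int main() {
    ElementTensorMetadata<half> fp16Elems{};  // one definition, two precisions
    ComposedTensorMetadata<float> fp32Comp{};
    (void)fp16Elems;
    (void)fp32Comp;
    return 0;
}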
@@ -7,7 +7,8 @@
 namespace infini {

 class CudaCompute {
-    void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
+    template <typename T>
+    void initComposedTensorMetadata(ComposedTensorMetadata<T> &metadata,
                                     Tensor tensor) const {
         int nDims = tensor->getRank();
         auto strides = tensor->getStride();
@@ -16,10 +17,10 @@ class CudaCompute {
             metadata.dimSize[i] = tensor->getDims().at(i);
             metadata.stride[i] = strides.at(i);
         }
-        metadata.data = tensor->getRawDataPtr<float *>();
+        metadata.data = tensor->getRawDataPtr<T *>();
     }

-    void initElementTensorMetadata(ElementTensorMetadata &metadata,
+    template <typename T>
+    void initElementTensorMetadata(ElementTensorMetadata<T> &metadata,
                                    TensorVec tensors, int idx, int dim,
                                    int &dimBgIdx, int &batchCounter) const {
         int nTensors = tensors.size();
@@ -27,7 +28,7 @@ class CudaCompute {
              ++batchCounter) {
             auto tensor = tensors.at(idx + batchCounter);
             auto dimSize = tensor->getDims()[dim];
-            metadata.data[batchCounter] = tensor->getRawDataPtr<float *>();
+            metadata.data[batchCounter] = tensor->getRawDataPtr<T *>();
             metadata.dimBgNo[batchCounter] = dimBgIdx;
             metadata.dimSize[batchCounter] = dimSize;
             metadata.nElements[batchCounter] = tensor->size();
@@ -36,17 +37,17 @@ class CudaCompute {
     }

   public:
+    template <typename T>
     void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
                     int nDims, bool isSplit) const {
         IT_ASSERT(nDims <= DIM_MAX_SIZE);

-        ComposedTensorMetadata composedMetadata;
-        initComposedTensorMetadata(composedMetadata, composedTensor);
+        ComposedTensorMetadata<T> composedMetadata;
+        initComposedTensorMetadata<T>(composedMetadata, composedTensor);

         int dimBgNo = 0;
         int nElemets = elementsTensor.size();
         for (int i = 0; i < nElemets; i += BATCH_SIZE) {
-            ElementTensorMetadata elemMetadata;
+            ElementTensorMetadata<T> elemMetadata;
             int batchCounter = 0;
             initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
                                       dimBgNo, batchCounter);
@@ -74,18 +75,30 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
                 }
             }
         }
-        do_compute(_op->getOutput(), _op->getInputs(),
-                   as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
-                   false);
+        if (_op->getDType() == DataType::Float32) {
+            do_compute<float>(_op->getOutput(), _op->getInputs(),
+                              as<ConcatObj>(_op)->getDim(),
+                              _op->getOutput()->getRank(), false);
+        } else if (_op->getDType() == DataType::Float16) {
+            do_compute<half>(_op->getOutput(), _op->getInputs(),
+                             as<ConcatObj>(_op)->getDim(),
+                             _op->getOutput()->getRank(), false);
+        }
     }
 };

 class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
-        do_compute(_op->getInputs(0), _op->getOutputs(),
-                   as<SplitObj>(_op)->getDim(), _op->getInputs(0)->getRank(),
-                   true);
+        if (_op->getDType() == DataType::Float32) {
+            do_compute<float>(_op->getInputs(0), _op->getOutputs(),
+                              as<SplitObj>(_op)->getDim(),
+                              _op->getInputs(0)->getRank(), true);
+        } else if (_op->getDType() == DataType::Float16) {
+            do_compute<half>(_op->getInputs(0), _op->getOutputs(),
+                             as<SplitObj>(_op)->getDim(),
+                             _op->getInputs(0)->getRank(), true);
+        }
     }
 };
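Note: ConcatCuda::compute and SplitCuda::compute now turn the operator's runtime dtype into a compile-time template argument for do_compute. A standalone sketch of that dispatch shape, using stubs in place of the InfiniTensor operator and kernel (DataType here is a simplified stand-in, not the real class):

// Sketch only; compile with nvcc, or g++ with the CUDA include path.
#include <cstdio>
#include <cuda_fp16.h> // for half

enum class DataType { Float32, Float16 }; // stand-in for infini::DataType

// Stub standing in for CudaCompute::do_compute<T>.
template <typename T> void do_compute_stub(int nElements) {
    std::printf("compute on %d scalars of %zu bytes each\n", nElements,
                sizeof(T));
}

// Runtime dtype -> compile-time type, as in ConcatCuda/SplitCuda above.
void dispatch(DataType dtype, int nElements) {
    if (dtype == DataType::Float32) {
        do_compute_stub<float>(nElements);
    } else if (dtype == DataType::Float16) {
        do_compute_stub<half>(nElements);
    } // as in the diff, any other dtype falls through without computing
}

int main() {
    dispatch(DataType::Float16, 48);
    return 0;
}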
@@ -1,9 +1,9 @@
 #include "cuda/cuda_common.h"
 #include "cuda/cuda_split_concat.h"

+template <typename T>
 __host__ __device__ int
 elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
-                       int nDim, ComposedTensorMetadata wholeMeta) {
+                       int nDim, ComposedTensorMetadata<T> wholeMeta) {
     int offset = 0;

     // COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1)
@@ -25,10 +25,10 @@ elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
     int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
     return offset + oP * wholeMeta.stride[0];
 }

-__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
-                                     ComposedTensorMetadata compMeta, int dim,
-                                     int nDims, bool isSplit) {
+template <typename T>
+__global__ void _split_concat_kernel(ElementTensorMetadata<T> elemMeta,
+                                     ComposedTensorMetadata<T> compMeta,
+                                     int dim, int nDims, bool isSplit) {
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
     int nElements = elemMeta.nElements[blockIdx.y];
     if (tid >= nElements)
@@ -36,10 +36,10 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,

     auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
     auto dimSize = elemMeta.dimSize[blockIdx.y];
-    float *elemData = elemMeta.data[blockIdx.y];
+    T *elemData = elemMeta.data[blockIdx.y];

     int Offset =
-        elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
+        elementIdx2ComposedIdx<T>(tid, dimBgNo, dimSize, dim, nDims, compMeta);
     // copy data from input to output
     // for split: input is the composed tensor; for concat: inputs are the
     // element tensors.
@@ -52,8 +52,22 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
 namespace infini {

 // TODO: when dim=0, the operation can be executed in-place
-void split_concat_kernel(const ElementTensorMetadata &eleMeta,
-                         const ComposedTensorMetadata &compMeta, int dim,
+void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
+                         const ComposedTensorMetadata<float> &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
     // gridsize = max_n_elements / blockSize
     int max_n_elements =
         *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
     int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
     // each y is a split among the batch
     dim3 gridSize(gridDimX, batchSize);

     _split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
                                                   isSplit);
 }
+void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
+                         const ComposedTensorMetadata<half> &compMeta, int dim,
+                         int batchSize, int nDims, bool isSplit) {
+    dim3 blockSize = dim3(32 * 16);
+    // gridsize = max_n_elements / blockSize
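Note: the launch configuration is untouched by the templating: each block runs 32 * 16 = 512 threads, gridDim.x covers the largest element tensor in the batch, gridDim.y selects the element tensor, and the tid >= nElements guard idles the surplus threads. A small self-contained rehearsal of that arithmetic, with made-up element counts:

// Plain C++; the element counts below are hypothetical, for illustration.
#include <algorithm>
#include <cstdio>

int main() {
    int nElements[] = {1000, 512, 2048}; // three element tensors in the batch
    int batchSize = 3;
    int blockSize = 32 * 16; // 512 threads per block, as in the kernel launch
    int maxN = *std::max_element(nElements, nElements + batchSize);
    int gridDimX = (maxN - 1) / blockSize + 1; // ceiling division -> 4
    // The kernel would launch with gridSize = dim3(4, 3); in rows where an
    // element tensor is smaller than maxN, threads with tid >= nElements[y]
    // return at the guard.
    std::printf("grid = (%d, %d), block = %d\n", gridDimX, batchSize,
                blockSize);
    return 0;
}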
@@ -187,4 +187,42 @@ TEST(ConcatToIdentity, Cuda) {
     EXPECT_TRUE(
         oCpu->equalData(vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
 }
+//----------
+TEST(ConcatFp16, CudaHigh) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto t1 = gCpu->addTensor({2, 2, 3, 1, 2}, DataType::Float16);
+    auto t2 = gCpu->addTensor({2, 2, 1, 1, 2}, DataType::Float16);
+    auto t3 = gCpu->addTensor({2, 2, 2, 1, 2}, DataType::Float16);
+    gCpu->dataMalloc();
+    t1->setData(ValGenerator<2>());
+    t2->setData(ValGenerator<1>());
+    t3->setData(ValGenerator<4>());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto t1Gpu = gCuda->cloneTensor(t1);
+    auto t2Gpu = gCuda->cloneTensor(t2);
+    auto t3Gpu = gCuda->cloneTensor(t3);
+
+    auto op =
+        gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 2);
+    gCuda->dataMalloc();
+    t1Gpu->setData(ValGenerator<2>());
+    t2Gpu->setData(ValGenerator<1>());
+    t3Gpu->setData(ValGenerator<4>());
+
+    cudaRuntime->run(gCuda);
+
+    // cudaPrintTensor(op->getOutput());
+    // copy output from CUDA to CPU
+    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2.,
+        2., 2., 1., 1., 4., 4., 4., 4., 2., 2., 2., 2., 2., 2., 1., 1.,
+        4., 4., 4., 4., 2., 2., 2., 2., 2., 2., 1., 1., 4., 4., 4., 4.}));
+}

 } // namespace infini
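Note: the expected values in ConcatFp16 follow from the shapes: {2,2,3,1,2}, {2,2,1,1,2}, and {2,2,2,1,2} concatenated along dim 2 give {2,2,6,1,2}, so each of the 2 * 2 outer slices holds six 2s, two 1s, then four 4s, 48 scalars in all. A quick check that reproduces the vector:

// Plain C++ sanity check for the expected data in the test above.
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> expected;
    for (int outer = 0; outer < 2 * 2; ++outer) {     // dims 0 and 1
        expected.insert(expected.end(), 3 * 1 * 2, 2.f); // slice from t1
        expected.insert(expected.end(), 1 * 1 * 2, 1.f); // slice from t2
        expected.insert(expected.end(), 2 * 1 * 2, 4.f); // slice from t3
    }
    std::printf("%zu elements\n", expected.size()); // 48, matching the test
    return 0;
}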
@@ -98,5 +98,35 @@ TEST(Split, Cuda_dim0) {
     EXPECT_TRUE(o0Cpu->equalData(vector<float>{0, 1, 2}));
     EXPECT_TRUE(o1Cpu->equalData(vector<float>{3, 4, 5}));
 }
+//----------------
+TEST(SplitFp16, CudaHigh) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float16);
+    gCpu->dataMalloc();
+    input->setData(ValGenerator<2>());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputGpu = gCuda->cloneTensor(input);
+    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, 3);
+    gCuda->dataMalloc();
+    inputGpu->setData(ValGenerator<2>());
+
+    cudaRuntime->run(gCuda);
+
+    // copy output from CUDA to CPU
+    EXPECT_EQ(op->getOutputs().size(), (size_t)3);
+    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
+    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
+    auto o2Cpu = gCpu->cloneTensor(op->getOutput(2));
+    EXPECT_TRUE(o0Cpu->equalData(vector<float>{
+        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
+    EXPECT_TRUE(o1Cpu->equalData(vector<float>{
+        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
+    EXPECT_TRUE(o2Cpu->equalData(vector<float>{
+        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.}));
+}
 } // namespace infini
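Note: the SplitFp16 expectations follow the same arithmetic: {2,6,2,1,2} split three ways along dim 1 yields three {2,2,2,1,2} outputs of 16 scalars each, all 2s because the input is constant. Spelled out:

// Plain C++ shape check for the split test above.
#include <cstdio>

int main() {
    int in[5] = {2, 6, 2, 1, 2};
    int parts = 3, dim = 1;
    int outDim = in[dim] / parts; // 6 / 3 = 2 along the split dimension
    int n = 1;
    for (int i = 0; i < 5; ++i)
        n *= (i == dim) ? outDim : in[i];
    std::printf("each output: %d scalars\n", n); // 16, matching the test
    return 0;
}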