ADD:concat/split operator and cuda kernels (#29)

* ADD:concat/split operator and cuda kernels

refector

minor change comment

ADD:concat/split operator and cuda kernels

merge split_kernel and concat_kernel to split_concat_kernel.

Revert "fix"

This reverts commit 459926be09a838658ec55f1e0a72b3cf17037d5c.

fix

ADD:concat/split operator and cuda kernels

change whole tensor name to composed tensor

fix some

remove unused header.

rebase

add CudaKernel

add test for split.

ADD split operator and cuda kernel.

modify test.

ADD:concat operator and cuda kernel.

ADD:concat/split operator and cuda kernels

fix some

remove unused header.

rebase

add CudaKernel

ADD:concat/split operator and cuda kernels

add test for split.

ADD split operator and cuda kernel.

modify test.

ADD:concat operator and cuda kernel.

* remove extra comment; typo fix.

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
wendy12022 2022-09-29 11:01:30 +08:00 committed by GitHub
parent 5560d0f2fb
commit 3c6e208f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 570 additions and 21 deletions

View File

@ -186,6 +186,10 @@ class OperatorObj : public Object {
IT_ASSERT(outputs.size() == 1, "Unimplemented"); IT_ASSERT(outputs.size() == 1, "Unimplemented");
return outputs[0]; return outputs[0];
} }
Tensor getOutput(size_t i) const {
IT_ASSERT(i < outputs.size(), "Index exceeded");
return outputs.at(i);
}
OpType getOpType() const { return type; } OpType getOpType() const { return type; }
// HACK: set correct data type // HACK: set correct data type
DataType getDType() const { return getInputs(0)->getDType(); } DataType getDType() const { return getInputs(0)->getDType(); }

View File

@ -1,25 +1,6 @@
#pragma once #pragma once
#include "operators/element_wise.h"
namespace infini { namespace infini {
void div_kernel(float *a, float *b, float *c, int num); void div_kernel(float *a, float *b, float *c, int num);
void pow_kernel(float *a, float *b, float *c, int num); void pow_kernel(float *a, float *b, float *c, int num);
void element_wise_kernel(const Operator &_op) {
auto op = as<ElementWiseObj>(_op);
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
if (op->getOpType() == OpType::Div)
div_kernel(aData, bData, cData, n * c * h * w);
else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, n * c * h * w);
else
IT_TODO_HALT();
}
}; // namespace infini }; // namespace infini

View File

@ -0,0 +1,35 @@
#pragma once
#include <cstdio>
const int BATCH_SIZE = 32; // parallel tensor number.
const int DIM_MAX_SIZE = 4;
// Concat operator acts like element tensors composing to one big tensor,and
// split operator acts like one big tensor being composed by element
// tensors.
struct ElementTensorMetadata {
float *data[BATCH_SIZE];
int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in
// the composed tensor.
int dimSize[BATCH_SIZE]; // the dimention size of the element tensor.
int nElements[BATCH_SIZE]; // the number of elements of the element tensor.
void print() const {
for (int i = 0; i < BATCH_SIZE; i++)
printf("%d:(data=%p,dimBgNo=%d,dimSize=%d,nElements=%d)\n", i,
data[i], dimBgNo[i], dimSize[i], nElements[i]);
}
};
struct ComposedTensorMetadata {
int dimSize[DIM_MAX_SIZE];
int stride[DIM_MAX_SIZE];
float *data;
};
namespace infini {
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
int batchSize, int nDims, bool isSplit);
} // namespace infini

View File

@ -0,0 +1,22 @@
#pragma once
#include "core/operator.h"
namespace infini {
class ConcatObj : public OperatorObj {
int dim;
public:
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -15,7 +15,7 @@ class PadObj : public OperatorObj {
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
Shape PadObj::getPads() const { return pads; } Shape getPads() const { return pads; }
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;

25
include/operators/split.h Normal file
View File

@ -0,0 +1,25 @@
#pragma once
#include "core/operator.h"
namespace infini {
class SplitObj : public OperatorObj {
int dim, num; // split dim;Average split num or outputs size
vector<int> ratio; // output dim ratio
public:
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
int dim, int num);
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return num; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -78,7 +78,19 @@ class MulCudnn : public ElementWiseCudnn {
class ElementWiseCuda : public CudaKernelWithoutConfig { class ElementWiseCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
element_wise_kernel(_op); auto op = as<ElementWiseObj>(_op);
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
if (op->getOpType() == OpType::Div)
div_kernel(aData, bData, cData, n * c * h * w);
else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, n * c * h * w);
else
IT_TODO_HALT();
} }
}; };

View File

@ -0,0 +1,81 @@
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_split_concat.h"
#include "operators/concat.h"
#include "operators/split.h"
#include <functional>
namespace infini {
void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
Tensor tensor) {
int nDims = tensor->getDims().size();
auto strides = tensor->getStride();
IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
}
metadata.data = tensor->getRawDataPtr<float *>();
}
void initElementTensorMetadata(ElementTensorMetadata &metadata,
TensorVec tensors, int idx, int dim,
int &dimBgIdx, int &batchCounter) {
int nTensors = tensors.size();
for (; batchCounter < BATCH_SIZE && idx + batchCounter < nTensors;
++batchCounter) {
auto tensor = tensors.at(idx + batchCounter);
auto dimSize = tensor->getDims()[dim];
metadata.data[batchCounter] = tensor->getRawDataPtr<float *>();
metadata.dimBgNo[batchCounter] = dimBgIdx;
metadata.dimSize[batchCounter] = dimSize;
metadata.nElements[batchCounter] = tensor->size();
dimBgIdx += dimSize;
}
}
class CudaCompute {
public:
void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
int nDims, bool isSplit) const {
IT_ASSERT(nDims <= DIM_MAX_SIZE);
ComposedTensorMetadata composedMetadata;
initComposedTensorMetadata(composedMetadata, composedTensor);
int dimBgNo = 0;
int nElemets = elementsTensor.size();
for (int i = 0; i < nElemets; i += BATCH_SIZE) {
ElementTensorMetadata elemMetadata;
int batchCounter = 0;
initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
dimBgNo, batchCounter);
split_concat_kernel(elemMetadata, composedMetadata, dim,
batchCounter, nDims, isSplit);
}
}
};
class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(),
_op->getOutput()->getDims().size(), false);
}
};
class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(),
_op->getInputs(0)->getDims().size(), true);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda,
"Concat_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda,
"Split_CUDA_Float32");
} // namespace infini

View File

@ -0,0 +1,71 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_split_concat.h"
int getMultiProcessorCount() {
int cur_device;
checkCudaError(cudaGetDevice(&cur_device));
struct cudaDeviceProp prop;
checkCudaError(cudaGetDeviceProperties(&prop, cur_device));
return prop.multiProcessorCount;
}
__host__ __device__ int
elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
int nDim, ComposedTensorMetadata wholeMeta) {
int offset = 0;
#pragma unroll
for (int i = nDim - 1; i >= 1; --i) {
int size = (i == dim) ? dimSize : wholeMeta.dimSize[i];
int p = elementIndex % size;
int oP = (i == dim) ? (p + dimBgNo) : p;
elementIndex = (elementIndex - p) / size;
offset += oP * wholeMeta.stride[i];
}
return offset + elementIndex * wholeMeta.stride[0];
}
__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
ComposedTensorMetadata compMeta, int dim,
int nDims, bool isSplit) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int nElements = elemMeta.nElements[blockIdx.y];
if (tid >= nElements)
return;
auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
auto dimSize = elemMeta.dimSize[blockIdx.y];
float *elemData = elemMeta.data[blockIdx.y];
int stride = gridDim.x * blockDim.x;
while (tid < nElements) {
int Offset =
elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
// copy data from input to output
// for split:input is composed tensor;for concat:input is element
// tensors.
if (isSplit)
elemData[tid] = compMeta.data[Offset];
else
compMeta.data[Offset] = elemData[tid];
tid += stride;
}
}
namespace infini {
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
int batchSize, int nDims, bool isSplit) {
dim3 blockSize = dim3(32 * 16);
// y dim is number of tensors.
dim3 gridSize(getMultiProcessorCount(), batchSize);
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
isSplit);
}
} // namespace infini

58
src/operators/concat.cc Normal file
View File

@ -0,0 +1,58 @@
#include "operators/concat.h"
namespace infini {
ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
: OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() > 1);
Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim);
for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) {
auto input = *itr;
auto iDims = input->getDims();
if (dims.size() != iDims.size())
return {};
int nDims = dims.size();
for (auto i = 0; i < nDims; i++) {
if (i == dim) {
n += iDims.at(i);
continue;
}
if (iDims.at(i) != dims.at(i))
return {};
}
}
dims[dim] = n;
return {{dims}};
}
std::string ConcatObj::toString() const {
std::ostringstream os;
os << "Concat[" << getGuid() << "]";
os << "(";
for (auto input : inputs)
os << vecToString(input->getDims()) << ",";
os << "dim=" << dim << ",";
os << "input=";
for (auto input : inputs)
os << input->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> ConcatObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), (int)inputs.size());
ret.emplace(ret.begin(), dim);
ret.emplace(ret.begin(), enum_to_underlying(type));
return ret;
}
vector<int> ConcatObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim};
}
} // namespace infini

89
src/operators/split.cc Normal file
View File

@ -0,0 +1,89 @@
#include "operators/split.h"
#include <numeric>
namespace infini {
SplitObj::SplitObj(GraphObj *graph, Tensor input,
std::optional<TensorVec> outputs, int dim, int num)
: OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec{nullptr} : (*outputs))),
dim(dim), num(num), ratio({}) {
int dimSize = input->getDims().at(dim);
int pieceSize = dimSize / num;
int lastSize = dimSize - pieceSize * num;
if (lastSize > 0) {
ratio = std::vector<int>(num - 1, pieceSize);
ratio.emplace_back(lastSize + pieceSize);
} else
ratio = std::vector<int>(num, pieceSize);
if (!outputs) {
TensorVec tmp(num, nullptr);
this->outputs = tmp;
}
IT_ASSERT(checkValid(graph));
}
SplitObj::SplitObj(GraphObj *graph, Tensor input,
std::optional<TensorVec> outputs, int dim,
const vector<int> &ratio)
: OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec{nullptr} : (*outputs))),
dim(dim), num(-1), ratio(ratio) {
num = ratio.size();
if (!outputs) {
TensorVec tmp(num, nullptr);
this->outputs = tmp;
}
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
if (num == -1 || ratio.size() == 0)
return {};
auto inputDims = inputs[0]->getDims();
int totalSize = inputDims.at(dim);
int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0);
if (totalSize % ratioSum != 0)
return {};
int pieceSize = totalSize / ratioSum;
vector<Shape> ret;
Shape outShape = inputDims;
for (int i = 0; i < num; i++) {
outShape[dim] = pieceSize * ratio.at(i);
ret.push_back(outShape);
}
return {ret};
}
vector<int> SplitObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace_back(dim);
ret.emplace_back(num);
return ret;
}
vector<int> SplitObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim, num};
}
string SplitObj::toString() const {
std::ostringstream os;
os << "Split[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "dim=" << dim << ",";
os << "num= " << num << ",";
os << "ratio= " << vecToString(ratio) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=";
for (auto i = 0; i < num; i++)
os << outputs[i]->getGuid() << ",";
os << ")";
return os.str();
}
} // namespace infini

View File

@ -0,0 +1,76 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/concat.h"
#include "test.h"
namespace infini {
/*
int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
int concatDim, int outputDimSize[4],
int outputStride[4], int nDim) {
int offset = 0;
for (int i = nDim - 1; i >= 1; --i) {
int size = (i == concatDim) ? dimSize : outputDimSize[i];
int p = linearIndex % size;
int oP = (i == concatDim) ? (p + dimBgNo) : p;
linearIndex = (linearIndex - p) / size;
offset += oP * outputStride[i];
}
return offset + linearIndex * outputStride[0];
}
TEST(Concat, OffsetTrans) {
int dimSize[] = {2, 3};
int strides[] = {3, 1};
int catDim = 1, nDim = 2;
EXPECT_EQ(inputOffset2CatOffset(0, 0, 1, catDim, dimSize, strides, nDim),
0);
EXPECT_EQ(inputOffset2CatOffset(1, 0, 1, catDim, dimSize, strides, nDim),
3);
EXPECT_EQ(inputOffset2CatOffset(0, 1, 2, catDim, dimSize, strides, nDim),
1);
EXPECT_EQ(inputOffset2CatOffset(1, 1, 2, catDim, dimSize, strides, nDim),
2);
EXPECT_EQ(inputOffset2CatOffset(2, 1, 2, catDim, dimSize, strides, nDim),
4);
EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim),
5);
}
*/
TEST(Concat, Cuda) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
auto t2 = gCpu->addTensor({2, 2, 1, 1}, DataType::Float32);
auto t3 = gCpu->addTensor({2, 2, 2, 1}, DataType::Float32);
gCpu->dataMalloc();
t1->setData(IncrementalGenerator());
t2->setData(OneGenerator());
t3->setData(OneGenerator());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto op = gCuda->addOp<ConcatObj>(TensorVec{gCuda->cloneTensor(t1),
gCuda->cloneTensor(t2),
gCuda->cloneTensor(t3)},
nullptr, 2);
gCuda->dataMalloc();
cudaRuntime->run(gCuda);
// cudaPrintTensor(op->getOutput());
// copy output from CUDA to CPU
auto oCpu = gCpu->cloneTensor(op->getOutput());
EXPECT_TRUE(
oCpu->equalData(vector<float>{0, 1, 2, 1, 1, 1, 3, 4, 5, 1, 1, 1,
6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
}
} // namespace infini

View File

@ -0,0 +1,40 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/split.h"
#include "test.h"
namespace infini {
TEST(Split, Cuda) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto input = gCpu->addTensor({2, 10, 2, 1}, DataType::Float32);
gCpu->dataMalloc();
input->setData(IncrementalGenerator());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto op =
gCuda->addOp<SplitObj>(gCuda->cloneTensor(input), std::nullopt, 1, 3);
gCuda->dataMalloc();
cudaRuntime->run(gCuda);
// copy output from CUDA to CPU
EXPECT_EQ(op->getOutputs().size(), (size_t)3);
auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
auto o2Cpu = gCpu->cloneTensor(op->getOutput(2));
EXPECT_TRUE(o0Cpu->equalData(
vector<float>{0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25}));
EXPECT_TRUE(o1Cpu->equalData(
vector<float>{6, 7, 8, 9, 10, 11, 26, 27, 28, 29, 30, 31}));
EXPECT_TRUE(o2Cpu->equalData(vector<float>{
12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39}));
}
} // namespace infini

View File

@ -0,0 +1,17 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/concat.h"
#include "test.h"
namespace infini {
TEST(Concat, ShapeInfer) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32);
auto t2 = g->addTensor({1, 3, 2, 5}, DataType::Float32);
auto op = g->addOp<ConcatObj>(TensorVec{t1, t2}, nullptr, 3);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
}
} // namespace infini

View File

@ -0,0 +1,38 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/split.h"
#include "test.h"
namespace infini {
TEST(Split, ShapeInfer) {
{
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
auto op = g->addOp<SplitObj>(input, std::nullopt, 3, 4);
EXPECT_EQ(op->numOutputs(), 4);
EXPECT_EQ(op->getOutputs().size(), (size_t)4);
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
}
{
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
auto op =
g->addOp<SplitObj>(input, std::nullopt, 3, vector<int>{1, 2, 2});
EXPECT_EQ(op->getOutputs().size(), (size_t)3);
EXPECT_EQ(op->numOutputs(), 3);
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 6}));
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 6}));
}
}
} // namespace infini