Add: efficient CUDA transpose for last two dims

Liyan Zheng 2023-05-05 15:16:07 +08:00
parent 6a70555892
commit abcfa76fb5
7 changed files with 243 additions and 17 deletions

View File

@ -18,6 +18,9 @@ class CudaRuntimeObj : public RuntimeObj {
bool cudaGraphStatus; // Whether CUDA graph stream capture is enabled
// CUDA device properties
cudaDeviceProp deviceProperties;
public:
CudaRuntimeObj();
virtual ~CudaRuntimeObj();
@ -54,6 +57,10 @@ class CudaRuntimeObj : public RuntimeObj {
IT_ASSERT(size <= workspaceSize);
return workspace;
}
pair<int, int> getComputeCapability() const {
return {deviceProperties.major, deviceProperties.minor};
}
int getNumSMs() const { return deviceProperties.multiProcessorCount; }
void copyBlobFromCPU(void *dst, const void *src,
size_t bytes) const override {
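The two accessors added to CudaRuntimeObj expose the cached cudaDeviceProp fields. A minimal sketch of querying them from a caller's side; the helper function and the exact header path are assumptions for illustration, not part of the commit:

#include "cuda/cuda_runtime.h" // path assumed from the repository layout
#include <cstdio>

// Illustrative only: read the cached device properties through the new
// accessors. The transpose kernel registered below uses getNumSMs() to
// size its grid and to pick a kernel specialization.
void printDeviceInfo(const infini::CudaRuntimeObj &cuda) {
    auto [major, minor] = cuda.getComputeCapability(); // e.g. {7, 0} on V100
    int numSMs = cuda.getNumSMs();                     // e.g. 80 on V100, 108 on A100
    std::printf("Compute capability %d.%d, %d SMs\n", major, minor, numSMs);
}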

View File

@ -10,4 +10,7 @@ void transpose_kernel(float *input, float *output, int nDims, int size,
vector<int> _dims_in, vector<int> _dims_out,
vector<int> _perms);
void invoke_transpose_last_two_dim(float *ptrA, float *ptrB, int dim0, int dim1,
int dim2, int numSMs);
} // namespace infini

View File

@ -646,7 +646,6 @@ class OnnxStub:
else:
name = f"weight{self.count_in}_{tensor.guid()}"
shape = tensor.shape()
print('shape=', shape)
data = np.random.randn(*shape)
self.initializers.append(
make_tensor(name, TensorProto.FLOAT, shape, data)

View File

@ -21,6 +21,8 @@ CudaRuntimeObj::CudaRuntimeObj()
checkCublasError(cublasSetStream(cublas, stream));
workspaceSize = 2ll << 30; // 2 GB
workspace = alloc(workspaceSize);
// Get CUDA device properties
checkCudaError(cudaGetDeviceProperties(&deviceProperties, 0));
}
CudaRuntimeObj::~CudaRuntimeObj() {

View File

@ -6,10 +6,8 @@
namespace infini {
class TransposeCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op);
void generic_transpose(const Ref<TransposeObj> &op,
const RuntimeObj *context) const {
auto input = op->getInputs(0);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
@ -42,6 +40,28 @@ class TransposeCuda : public CudaKernelWithoutConfig {
strides, outputDims, input->getDims(),
output->getDims(), perm);
}
void fast_transpose_last_dim(const Ref<TransposeObj> &op,
const RuntimeObj *context) const {
// Permutation {0, 2, 3, 1}, i.e. NCHW -> NHWC: equivalent to transposing
// the last two dims of the folded [N, C, H*W] view, so the dedicated
// last-two-dim kernel can be used.
auto cuda = dynamic_cast<const CudaRuntimeObj *>(context);
auto shape = op->getOutput()->getDims();
invoke_transpose_last_two_dim(
op->getInputs(0)->getRawDataPtr<float *>(),
op->getOutput()->getRawDataPtr<float *>(), shape[0],
shape[1] * shape[2], shape[3], cuda->getNumSMs());
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op);
const auto &perm = op->getPermute();
if (perm == vector{0, 2, 3, 1}) {
fast_transpose_last_dim(op, _context);
} else {
generic_transpose(op, _context);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Transpose, DataType::Float32,
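For reference, the equivalence the fast path relies on: applying permutation {0, 2, 3, 1} to an [n, c, h, w] tensor is the same as transposing the last two dimensions of the folded [n, c, h*w] view. A minimal CPU sketch of that mapping (illustrative only; the function name is made up and is not part of the commit):

#include <vector>

// Naive reference for the mapping used above: out[i0][i1][i2] = in[i0][i2][i1],
// with dim0 = n, dim1 = h*w, dim2 = c, matching the arguments passed to
// invoke_transpose_last_two_dim in fast_transpose_last_dim.
void transposeLastTwoDimsRef(const std::vector<float> &in, std::vector<float> &out,
                             int dim0, int dim1, int dim2) {
    for (int i0 = 0; i0 < dim0; ++i0)
        for (int i1 = 0; i1 < dim1; ++i1)
            for (int i2 = 0; i2 < dim2; ++i2)
                out[(size_t(i0) * dim1 + i1) * dim2 + i2] =
                    in[(size_t(i0) * dim2 + i2) * dim1 + i1];
}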

View File

@ -0,0 +1,194 @@
#include "cuda/cuda_common.h"
#include <assert.h>
#include <vector>
template <int numSM, int numWarp>
__global__ void kernel_transpose_last(float *ptrA, float *ptrB, int dim0,
int dim1, int dim2) {
int laneId = threadIdx.x % 32;
int warpId = blockIdx.x * numWarp + threadIdx.x / 32;
int n1 = (dim1 + 31) / 32;
int n2 = (dim2 + 31) / 32;
float bufA[32];
// Grid-stride loop: each warp transposes one 32x32 tile per iteration. The
// tile is staged in registers, distributed diagonally across the warp (lane
// l holds column (l + r) % 32 of tile row r), so the loads and stores below
// stay coalesced without going through shared memory.
for (int i = warpId; i < dim0 * n1 * n2; i += numSM * numWarp) {
// clock_t ck0 = clock();
int i0 = i / (n1 * n2);
int i1 = (i % (n1 * n2)) / n2;
int i2 = (i % (n1 * n2)) % n2;
int offsetA = i0 * dim1 * dim2 + i2 * 32 * dim1 + i1 * 32;
int offsetB = i0 * dim1 * dim2 + i1 * 32 * dim2 + i2 * 32;
int ld1 = min(32, dim1 - i1 * 32);
int ld2 = min(32, dim2 - i2 * 32);
// if (i == 4 && laneId == 0)
// printf("%d %d\n", ld1, ld2);
// Unrolled fast paths for common tile widths: 32 (a full tile) plus the
// residuals 17 (= 49 % 32) and 4 (= 196 % 32) from 7x7 and 14x14 feature maps.
if (ld2 == 32) {
#pragma unroll
for (int i = 0; i < 32; i++) {
if ((laneId + i) % 32 < ld1) {
bufA[i] = ptrA[offsetA + i * dim1 + (laneId + i) % 32];
}
}
} else if (ld2 == 17) {
#pragma unroll
for (int i = 0; i < 17; i++) {
if ((laneId + i) % 32 < ld1) {
bufA[i] = ptrA[offsetA + i * dim1 + (laneId + i) % 32];
}
}
} else if (ld2 == 4) {
#pragma unroll
for (int i = 0; i < 4; i++) {
if ((laneId + i) % 32 < ld1) {
bufA[i] = ptrA[offsetA + i * dim1 + (laneId + i) % 32];
}
}
} else {
for (int i = 0; i < ld2; i++) {
if ((laneId + i) % 32 < ld1) {
bufA[i] = ptrA[offsetA + i * dim1 + (laneId + i) % 32];
}
}
}
if (ld1 == 32) {
#pragma unroll
for (int i = 0; i < 32; i++) {
if ((i + 32 - laneId) % 32 < ld2) {
ptrB[offsetB + i * dim2 + (i + 32 - laneId) % 32] =
bufA[(i + 32 - laneId) % 32];
}
}
} else if (ld1 == 17) {
#pragma unroll
for (int i = 0; i < 17; i++) {
if ((i + 32 - laneId) % 32 < ld2) {
ptrB[offsetB + i * dim2 + (i + 32 - laneId) % 32] =
bufA[(i + 32 - laneId) % 32];
}
}
} else if (ld1 == 4) {
#pragma unroll
for (int i = 0; i < 4; i++) {
if ((i + 32 - laneId) % 32 < ld2) {
ptrB[offsetB + i * dim2 + (i + 32 - laneId) % 32] =
bufA[(i + 32 - laneId) % 32];
}
}
} else {
for (int i = 0; i < ld1; i++) {
if ((i + 32 - laneId) % 32 < ld2) {
ptrB[offsetB + i * dim2 + (i + 32 - laneId) % 32] =
bufA[(i + 32 - laneId) % 32];
}
}
}
}
}
namespace infini {
/// @brief Transpose the last two dimensions: ptrB[i0][i1][i2] = ptrA[i0][i2][i1].
/// @param ptrA Input tensor of shape [dim0, dim2, dim1]
/// @param ptrB Output tensor of shape [dim0, dim1, dim2]
/// @param dim0 Leading (batch) dimension, unchanged by the transpose
/// @param dim1 Second dimension of the output (last dimension of the input)
/// @param dim2 Last dimension of the output (second dimension of the input)
/// @param numSMs Number of SMs on the device, used to pick a kernel specialization
void invoke_transpose_last_two_dim(float *ptrA, float *ptrB, int dim0, int dim1,
int dim2, int numSMs) {
constexpr int numWarps = 4;
dim3 gridDim(numSMs, 1);
dim3 blockDim(numWarps * 32, 1);
if (numSMs == 80) { // V100
kernel_transpose_last<80, numWarps>
<<<gridDim, blockDim>>>(ptrA, ptrB, dim0, dim1, dim2);
} else if (numSMs == 108) { // A100
kernel_transpose_last<108, numWarps>
<<<gridDim, blockDim>>>(ptrA, ptrB, dim0, dim1, dim2);
} else {
IT_TODO_HALT_MSG(std::string("transpose_last_two_dim with ") +
std::to_string(numSMs) + " SMs is not implemented");
}
// cudaCheckError();
}
} // namespace infini
// constexpr int numWarm = 128, numEval = 128;
//
// void eval_transpose_last(const std::vector<int> &shape) {
// assert(shape.size() == 3);
// int size = shape[0] * shape[1] * shape[2];
// float *dataA, *dataB;
// dataA = (float *)malloc(size * sizeof(float));
// dataB = (float *)malloc(size * sizeof(float));
// for (int i0 = 0; i0 < shape[0]; i0++) {
// for (int i2 = 0; i2 < shape[2]; i2++) {
// for (int i1 = 0; i1 < shape[1]; i1++) {
// dataA[i0 * shape[1] * shape[2] + i2 * shape[1] + i1] =
// i0 * shape[1] * shape[2] + i2 * shape[1] + i1;
// }
// }
// }
// float *ptrA, *ptrB;
// checkCudaError(cudaMalloc(&ptrA, size * sizeof(float)));
// checkCudaError(cudaMalloc(&ptrB, size * sizeof(float)));
// checkCudaError(
// cudaMemcpy(ptrA, dataA, size * sizeof(float),
// cudaMemcpyHostToDevice));
// invoke_transpose_last_two_dim(ptrA, ptrB, shape[0], shape[1], shape[2]);
// checkCudaError(
// cudaMemcpy(dataB, ptrB, size * sizeof(float),
// cudaMemcpyDeviceToHost));
// for (int i0 = 0; i0 < shape[0]; i0++) {
// for (int i1 = 0; i1 < shape[1]; i1++) {
// for (int i2 = 0; i2 < shape[2]; i2++) {
// if (dataA[i0 * shape[1] * shape[2] + i1 + i2 * shape[1]] !=
// dataB[i0 * shape[1] * shape[2] + i1 * shape[2] + i2]) {
// std::cout
// << i0 << " " << i1 << " " << i2 << " "
// << dataA[i0 * shape[1] * shape[2] + i1 + i2 *
// shape[1]]
// << " "
// << dataB[i0 * shape[1] * shape[2] + i1 * shape[2] +
// i2]
// << std::endl;
// exit(-1);
// }
// }
// }
// }
// cudaEvent_t st, ed;
// checkCudaError(cudaEventCreate(&st));
// checkCudaError(cudaEventCreate(&ed));
// for (int i = 0; i < numWarm; i++) {
// invoke_transpose_last_two_dim(ptrA, ptrB, shape[0], shape[1],
// shape[2]);
// }
// checkCudaError(cudaEventRecord(st));
// for (int i = 0; i < numEval; i++) {
// invoke_transpose_last_two_dim(ptrA, ptrB, shape[0], shape[1],
// shape[2]);
// }
// checkCudaError(cudaEventRecord(ed));
// checkCudaError(cudaEventSynchronize(st));
// checkCudaError(cudaEventSynchronize(ed));
// float time;
// checkCudaError(cudaEventElapsedTime(&time, st, ed));
// float bandwidth = size * 2 * sizeof(float) * numEval / time / 1e6;
// std::cout << "transpose_last: " << shape[0] << " " << shape[1] << " "
// << shape[2] << " time: " << time / numEval
// << " ms. bandwidth: " << bandwidth << " GB/s" << std::endl;
// }
// Performance evaluation
// int main() {
// eval_transpose_last({16, 1024, 256});
// eval_transpose_last({16, 14 * 14, 1024});
// eval_transpose_last({16, 7 * 7, 2048});
// eval_transpose_last({16, 7 * 7, 128});
// eval_transpose_last({1, 14 * 14, 1024});
// eval_transpose_last({1, 7 * 7, 2048});
// eval_transpose_last({1, 7 * 7, 128});
// }
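Condensed from the commented-out evaluation harness above, a minimal host-side sketch of driving the new entry point. The shape, the wrapper function, and the hard-coded SM count (108, an A100) are illustrative; real callers should obtain the count from CudaRuntimeObj::getNumSMs():

#include "cuda/cuda_common.h"

namespace infini { // declaration matches the header added in this commit
void invoke_transpose_last_two_dim(float *ptrA, float *ptrB, int dim0, int dim1,
                                   int dim2, int numSMs);
}

void runTransposeLastTwoDim() {
    const int dim0 = 16, dim1 = 14 * 14, dim2 = 1024; // output is [dim0, dim1, dim2]
    const size_t bytes = size_t(dim0) * dim1 * dim2 * sizeof(float);
    float *ptrA, *ptrB;
    checkCudaError(cudaMalloc(&ptrA, bytes)); // input,  viewed as [dim0, dim2, dim1]
    checkCudaError(cudaMalloc(&ptrB, bytes)); // output, viewed as [dim0, dim1, dim2]
    // ... fill ptrA from the host with cudaMemcpy ...
    infini::invoke_transpose_last_two_dim(ptrA, ptrB, dim0, dim1, dim2,
                                          /*numSMs=*/108);
    checkCudaError(cudaDeviceSynchronize());
    checkCudaError(cudaFree(ptrA));
    checkCudaError(cudaFree(ptrB));
}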

View File

@ -8,10 +8,9 @@
namespace infini {
template <class T>
void testTranspose(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
const Shape &shape, const Shape &permute, vector<float> ans) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
@ -24,22 +23,24 @@ void testTranspose(
// GPU
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
vector<int> permute = {0, 2, 1, 3};
auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr, permute);
auto gpuOp = cudaGraph->addOp<TransposeObj>(inputGpu, nullptr, permute);
cudaGraph->dataMalloc();
cudaRuntime->run(cudaGraph);
auto outputGpu = gpuOp->getOutput();
auto oCpu = outputGpu->clone(cpuRuntime);
// Check
// inputCpu->printData();
// oCpu->printData();
EXPECT_TRUE(oCpu->equalData(vector<float>{0, 1, 2, 3, 12, 13, 14, 15,
4, 5, 6, 7, 16, 17, 18, 19,
8, 9, 10, 11, 20, 21, 22, 23}));
EXPECT_TRUE(oCpu->equalData(ans));
}
TEST(cuda_Transpose, run) {
testTranspose<TransposeObj>(IncrementalGenerator(), Shape{1, 2, 3, 4});
TEST(cuda_Transpose, run_generic) {
testTranspose(IncrementalGenerator(), {1, 2, 3, 4}, {0, 2, 1, 3},
{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7,
16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23});
}
TEST(cuda_Transpose, run_fast_last_dim) {
testTranspose(IncrementalGenerator(), {1, 2, 3, 4}, {0, 2, 3, 1},
{0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17,
6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23});
}
} // namespace infini