forked from jiuyuan/InfiniTensor
feat: add support for dynamic_quantize_linear
parent 0e75f99e7e
commit 8ae5958b29

@@ -1,5 +1,5 @@
#pragma once
-#include "operators/unary.h"
+#include "operators/dequantize_linear.h"

namespace infini {
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,

@@ -0,0 +1,7 @@
#pragma once
#include "operators/dynamic_quantize_linear.h"

namespace infini {
void dynamicQuantizeLinearKernel(float *input, uint8_t *outputY, float *yScale,
                                 uint8_t *yZeroPoint, int size);
}; // namespace infini
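The declaration above matches the ONNX DynamicQuantizeLinear operator: from one float input it produces a quantized uint8 tensor, a single scale, and a single zero point. A minimal CPU sketch of that math, useful for checking the CUDA kernel's outputs (the function name is hypothetical, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical reference; mirrors the formula the kernels below implement.
void dynamicQuantizeLinearRef(const float *x, uint8_t *y, float *yScale,
                              uint8_t *yZeroPoint, int size) {
    float xMin = 0.f, xMax = 0.f; // ONNX extends the range to include 0
    for (int i = 0; i < size; ++i) {
        xMin = std::min(xMin, x[i]);
        xMax = std::max(xMax, x[i]);
    }
    *yScale = (xMax - xMin) / 255.f; // qmax - qmin = 255 - 0
    float zp = std::round(std::clamp(-xMin / *yScale, 0.f, 255.f));
    *yZeroPoint = static_cast<uint8_t>(zp);
    for (int i = 0; i < size; ++i) {
        y[i] = static_cast<uint8_t>(
            std::clamp(std::round(x[i] / *yScale) + zp, 0.f, 255.f));
    }
}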

@@ -0,0 +1,31 @@
#include "operators/dynamic_quantize_linear.h"
#include "cuda/cuda_dynamic_quantize_linear.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class DynamicQuantizeLinearCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<DynamicQuantizeLinearObj>(_op);

        void *const input = (op->getInputs(0)->getRawDataPtr<void *>());

        void *const outputY = (op->getOutput(0)->getRawDataPtr<void *>());
        void *const outputYScale = (op->getOutput(1)->getRawDataPtr<void *>());
        void *const outputYZeroPoint =
            (op->getOutput(2)->getRawDataPtr<void *>());

        int size = op->getInputs(0)->size();

        dynamicQuantizeLinearKernel((float *)input, (uint8_t *)outputY,
                                    (float *)outputYScale,
                                    (uint8_t *)outputYZeroPoint, size);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::DynamicQuantizeLinear,
                DynamicQuantizeLinearCuda, "DynamicQuantizeLinear_CUDA");

}; // namespace infini
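Once registered, the kernel is reachable through the normal graph path. A hedged sketch of how a test might drive it; the shape, data generator, and the exact addOp/DynamicQuantizeLinearObj signatures are assumptions modeled on other InfiniTensor CUDA op tests, not taken from this commit:

// Hypothetical usage sketch; signatures are assumptions, see note above.
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto input = gCuda->addTensor({2, 3, 4}, DataType::Float32);
// One input, three outputs (y, yScale, yZeroPoint); std::nullopt lets the
// graph create the output tensors itself.
auto op = gCuda->addOp<DynamicQuantizeLinearObj>(input, std::nullopt);
gCuda->dataMalloc();
input->setData(IncrementalGenerator());
cudaRuntime->run(gCuda);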

@@ -1,170 +1,174 @@
#include <cub/cub.cuh>

#include "cuda/cuda_common.h"

__device__ float _saturate(float x) {
    return x < 0.f ? 0.f : (x > 255.0 ? 255.0 : x);
}

template <class T>
__device__ __forceinline__ static T max___(T a, T b) noexcept {
    return a > b ? a : b;
}

template <class T>
__device__ __forceinline__ static T min___(T a, T b) noexcept {
    return a < b ? a : b;
}
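// Added note (not part of the commit): max___ and min___ are simple
// device-side comparison helpers usable with any comparable type; the
// underscored names presumably avoid clashing with the max/min overloads
// CUDA already provides in device code.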
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
    void _dynamicQuantizeLinearKernel(float *input, uint8_t *outputY,
                                      float *yScale, uint8_t *yZeroPoint,
                                      int size) {
    int i = threadIdx.x + blockIdx.x * BLOCK_DIM;
    float maxData = -__FLT_MAX__; // identity for the max reduction
    float minData = __FLT_MAX__;  // identity for the min reduction
    int remain = size % BLOCK_DIM;
    int step = (size - remain) / BLOCK_DIM + 1;
    // Each block scans the whole input (indexing by threadIdx.x only), so
    // every block arrives at the same min/max, scale, and zero point.
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            maxData = max___(maxData, input[threadIdx.x * step + ind]);
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            maxData = max___(maxData,
                             input[remain * step +
                                   (threadIdx.x - remain) * (step - 1) + ind]);
        }
    }
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            minData = min___(minData, input[threadIdx.x * step + ind]);
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            minData = min___(minData,
                             input[remain * step +
                                   (threadIdx.x - remain) * (step - 1) + ind]);
        }
    }
    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float maxTotal;
    float blockMax = BlockReduce(temp_storage).Reduce(maxData, cub::Max());
    __shared__ float minTotal;
    __syncthreads(); // temp_storage is reused below; CUB requires a barrier
    float blockMin = BlockReduce(temp_storage).Reduce(minData, cub::Min());
    if (threadIdx.x == 0) {
        maxTotal = blockMax;
        minTotal = blockMin;
    }
    __syncthreads();
    int qmax = 255;
    int qmin = 0;
    yScale[0] = (max___(0.f, maxTotal) - min___(0.f, minTotal)) / (qmax - qmin);
    float intermediate_zero_point = qmin - minTotal / yScale[0];
    float _yZeroPoint = round(_saturate(intermediate_zero_point));
    yZeroPoint[0] = static_cast<uint8_t>(_yZeroPoint);
    if (i < size) {
        outputY[i] = static_cast<uint8_t>(
            _saturate(round(input[i] / yScale[0]) + _yZeroPoint));
    }
}
//----------
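// Added worked example (not from the commit): with size = 10 and
// BLOCK_DIM = 4, remain = 2 and step = 3. Threads 0 and 1 scan step = 3
// elements each (indices 0-2 and 3-5); threads 2 and 3 scan step - 1 = 2
// elements each (indices 6-7 and 8-9). Together they cover all 10 elements
// exactly once, which is the contract both kernels' loops rely on.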
template <int BLOCK_DIM, int numPerThread>
__launch_bounds__(BLOCK_DIM) __global__
    void _dynamicQuantizeLinearKernel(float *input, uint8_t *outputY,
                                      float *yScale, uint8_t *yZeroPoint,
                                      int size) {
    int i = threadIdx.x + blockIdx.x * BLOCK_DIM;
    float maxData = -__FLT_MAX__; // identity for the max reduction
    float minData = __FLT_MAX__;  // identity for the min reduction
    int remain = size % BLOCK_DIM;
    int step = (size - remain) / BLOCK_DIM + 1;
    // This variant caches each thread's slice in registers so the min pass
    // does not re-read global memory.
    float dataPerThread[numPerThread];
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            dataPerThread[ind] = input[threadIdx.x * step + ind];
            maxData = max___(maxData, dataPerThread[ind]);
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            dataPerThread[ind] =
                input[remain * step + (threadIdx.x - remain) * (step - 1) +
                      ind];
            maxData = max___(maxData, dataPerThread[ind]);
        }
    }
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            minData = min___(minData, dataPerThread[ind]);
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            minData = min___(minData, dataPerThread[ind]);
        }
    }
    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float maxTotal;
    float blockMax = BlockReduce(temp_storage).Reduce(maxData, cub::Max());
    __shared__ float minTotal;
    __syncthreads(); // temp_storage is reused below; CUB requires a barrier
    float blockMin = BlockReduce(temp_storage).Reduce(minData, cub::Min());
    if (threadIdx.x == 0) {
        maxTotal = blockMax;
        minTotal = blockMin;
    }
    __syncthreads();
    int qmax = 255;
    int qmin = 0;
    yScale[0] = (max___(0.f, maxTotal) - min___(0.f, minTotal)) / (qmax - qmin);
    float intermediate_zero_point = qmin - minTotal / yScale[0];
    float _yZeroPoint = round(_saturate(intermediate_zero_point));
    yZeroPoint[0] = static_cast<uint8_t>(_yZeroPoint);
    if (i < size) {
        outputY[i] = static_cast<uint8_t>(
            _saturate(round(input[i] / yScale[0]) + _yZeroPoint));
    }
}

namespace infini {
void dynamicQuantizeLinearKernel(float *input, uint8_t *outputY, float *yScale,
                                 uint8_t *yZeroPoint, int size) {
    // Each branch picks a numPerThread just large enough to cache this size
    // range's per-thread slice; beyond 128 elements per thread the uncached
    // single-parameter kernel is used instead.
    if (size > 1024 * 128) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else if (size > 1024 * 64) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 128><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else if (size > 1024 * 32) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 64><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else if (size > 1024 * 16) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 32><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else if (size > 1024 * 4) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 16><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else if (size > 1024) {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 4><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    } else {
        int BLOCK_DIM = 1024;
        int num_blocks = (size + BLOCK_DIM - 1) / BLOCK_DIM;
        _dynamicQuantizeLinearKernel<1024, 1><<<num_blocks, BLOCK_DIM>>>(
            input, outputY, yScale, yZeroPoint, size);
    }
}
} // namespace infini