add half and float dequantizeLinear

This commit is contained in:
xgqdut2016 2023-12-18 17:47:53 +08:00
parent 03ed8c4de7
commit 9c82936386
10 changed files with 693 additions and 45 deletions

View File

@ -85,6 +85,8 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to); Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims); Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
Tensor dequantizeLinear(Tensor inputX, Tensor inputScale, Tensor output,
Tensor inputZeroPoint, int axis);
std::vector<int> getDims(Tensor x) { return x->getDims(); } std::vector<int> getDims(Tensor x) { return x->getDims(); }
Tensor allReduceSum(Tensor input, Tensor output); Tensor allReduceSum(Tensor input, Tensor output);
@ -102,9 +104,6 @@ class GraphHandlerObj {
TensorVec dynamicQuantizeLinear(Tensor input, TensorVec dynamicQuantizeLinear(Tensor input,
std::optional<TensorVec> outputs); std::optional<TensorVec> outputs);
Tensor dequantizeLinear(Tensor input, Tensor scale, Tensor zero_point,
Tensor output, int axis);
//------ modifiers //------ modifiers
inline bool topo_sort() { return g->topo_sort(); } inline bool topo_sort() { return g->topo_sort(); }

View File

@ -0,0 +1,17 @@
#pragma once
#include "operators/unary.h"
namespace infini {
/**
 * @brief Host-side launchers for the CUDA DequantizeLinear kernels.
 *
 * Every overload computes, for each element of the quantized input,
 *   output = (inputX - inputZeroPoint) * inputScale
 * where the scale (and the zero point, when given) is selected by the
 * element's index along the dequantized axis. The overloads without
 * `inputZeroPoint` omit the subtraction.
 *
 * All pointers are dereferenced inside CUDA kernels, so they must refer to
 * memory accessible from the device.
 *
 * @param inputX          Quantized uint8 input data.
 * @param inputScale      Per-index scale along the axis; indexed over
 *                        [0, dimsize).
 * @param output          Destination buffer, same element count as inputX.
 * @param dimsize         Extent of the dequantized axis.
 * @param stride          Element stride of the dequantized axis.
 * @param inputZeroPoint  Per-index uint8 zero point; indexed over
 *                        [0, dimsize).
 * @param size            Total number of output elements.
 */
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
                            float *output, const int dimsize, const int stride,
                            const uint8_t *inputZeroPoint, const int size);

/// float scale/output, no zero point (see the overload above for parameters).
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
                            float *output, const int dimsize, const int stride,
                            const int size);

/// half scale/output, with zero point.
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
                            half *output, const int dimsize, const int stride,
                            const uint8_t *inputZeroPoint, const int size);

/// half scale/output, no zero point.
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
                            half *output, const int dimsize, const int stride,
                            const int size);
} // namespace infini

View File

@ -3,9 +3,8 @@
namespace infini { namespace infini {
/** /**
* @brief The linear dequantization operator. * @brief y = (x - x_zero_point) * x_scale
* It consumes a quantized tensor, a scale, and a zero point to compute *
* the full precision tensor.
*/ */
class DequantizeLinearObj : public OperatorObj { class DequantizeLinearObj : public OperatorObj {
int axis; int axis;
@ -15,26 +14,30 @@ class DequantizeLinearObj : public OperatorObj {
* @brief Construct a new DequantizeLinear object. * @brief Construct a new DequantizeLinear object.
* *
* @param graph The computation graph that this operator belongs to. * @param graph The computation graph that this operator belongs to.
* @param input The input tensor. * @param inputX The input tensor X.
* @param scale Scale for input. * @param inputScale The input tensor x_scale.
* @param zero_point Zero point for input. * @param output The output tensor.
* @param outputs The output tensors. * @param inputZeroPoint The x_zero_point.
* @param axis The axis of the dequantizing dimension of the input tensor.
*/ */
DequantizeLinearObj(GraphObj *graph, Tensor input, Tensor scale, DequantizeLinearObj(GraphObj *graph, Tensor inputX, Tensor inputScale,
Tensor zero_pointr, Tensor output, int axis); Tensor output, Tensor inputZeroPoint = nullptr,
int axis = 1);
OP_CLONE(DequantizeLinearObj); OP_CLONE(DequantizeLinearObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
Tensor getZeroPoint() const {
return inputs.size() > 2 ? inputs[2] : nullptr;
}
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
int getAxis() const { return axis; }
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;
vector<DataType> inferDataType(const TensorVec &inputs) const override; vector<DataType> inferDataType(const TensorVec &inputs) const override;
}; };

View File

@ -866,19 +866,23 @@ class OnnxStub:
): ):
tensors[name] = tensor tensors[name] = tensor
elif node.op_type == "DequantizeLinear": elif node.op_type == "DequantizeLinear":
# DequantizeLinear: inputs are (x, x_scale[, x_zero_point]); the zero
# point is optional, so pass None when the node has only two inputs.
(inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1])
inputZeroPoint = (
    None if len(node.input) < 3 else tensors[node.input[2]]
)
output = tensors.get(node.output[0])
# ONNX specifies axis=1 when the attribute is absent; this also matches
# the default of the C++ DequantizeLinearObj constructor.
axis = next(
    (attr.i for attr in node.attribute if attr.name == "axis"),
    1,
)
tensors[node.output[0]] = self.handler.dequantizeLinear(
    inputX,
    inputScale,
    output,
    inputZeroPoint,
    axis,
)
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name) new_node_name.append(node.name)

View File

@ -6,7 +6,7 @@
#include "operators/broadcast.h" #include "operators/broadcast.h"
#include "operators/concat.h" #include "operators/concat.h"
#include "operators/conv.h" #include "operators/conv.h"
#include "operators/dequantize_linear.h" #include "operators/dequantizeLinear.h"
#include "operators/dynamic_quantize_linear.h" #include "operators/dynamic_quantize_linear.h"
#include "operators/element_wise.h" #include "operators/element_wise.h"
#include "operators/expand.h" #include "operators/expand.h"
@ -521,18 +521,19 @@ GraphHandlerObj::dynamicQuantizeLinear(Tensor input,
} }
} }
Tensor GraphHandlerObj::dequantizeLinear(Tensor input, Tensor scale, Tensor GraphHandlerObj::dequantizeLinear(Tensor inputX, Tensor inputScale,
Tensor zero_point, Tensor output, Tensor output, Tensor inputZeroPoint,
int axis) { int axis) {
if (output) { if (output) {
g->addOpWithOutputs<DequantizeLinearObj>( g->addOpWithOutputs<DequantizeLinearObj>(
std::move(input), std::move(scale), std::move(zero_point), output, std::move(inputX), std::move(inputScale), output,
axis); std::move(inputZeroPoint), axis);
return output; return output;
} else { } else {
return g return g
->addOp<DequantizeLinearObj>(std::move(input), std::move(scale), ->addOp<DequantizeLinearObj>(std::move(inputX),
std::move(zero_point), output, axis) std::move(inputScale), output,
std::move(inputZeroPoint), axis)
->getOutput(); ->getOutput();
} }
} }

View File

@ -517,6 +517,7 @@ void init_graph_builder(py::module &m) {
.def("expand", &Handler::expand, policy::move) .def("expand", &Handler::expand, policy::move)
.def("erf", &Handler::erf, policy::move) .def("erf", &Handler::erf, policy::move)
.def("where", &Handler::where, policy::move) .def("where", &Handler::where, policy::move)
.def("dequantizeLinear", &Handler::dequantizeLinear, policy::move)
.def("topo_sort", &Handler::topo_sort, policy::automatic) .def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("optimize", &Handler::optimize, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic)
.def("operators", &Handler::operators, policy::move) .def("operators", &Handler::operators, policy::move)

View File

@ -0,0 +1,57 @@
#include "operators/dequantizeLinear.h"
#include "cuda/cuda_dequantizeLinear.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class DequantizeLinearCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DequantizeLinearObj>(_op);
void *const inputX = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inputScale = (op->getInputs(1)->getRawDataPtr<void *>());
void *const output = (op->getOutput()->getRawDataPtr<void *>());
const int axis = op->getAxis();
const int stride = op->getInputs(0)->getStride().at(axis);
auto dims = op->getInputs(0)->getDims();
int dimsize = dims[op->getAxis()];
int size = op->getOutput()->size();
if (op->getInputs(1)->getDType() == DataType::Float32) {
if (op->numInputs() == 3) {
void *const inputZeroPoint =
(op->getInputs(2)->getRawDataPtr<void *>());
DequantizeLinearKernel((uint8_t *)inputX, (float *)inputScale,
(float *)output, dimsize, stride,
(uint8_t *)inputZeroPoint, size);
} else {
DequantizeLinearKernel((uint8_t *)inputX, (float *)inputScale,
(float *)output, dimsize, stride, size);
}
} else if (op->getInputs(1)->getDType() == DataType::Float16) {
if (op->numInputs() == 3) {
void *const inputZeroPoint =
(op->getInputs(2)->getRawDataPtr<void *>());
DequantizeLinearKernel((uint8_t *)inputX, (half *)inputScale,
(half *)output, dimsize, stride,
(uint8_t *)inputZeroPoint, size);
} else {
DequantizeLinearKernel((uint8_t *)inputX, (half *)inputScale,
(half *)output, dimsize, stride, size);
}
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::DequantizeLinear, DequantizeLinearCuda,
"DequantizeLinear_CUDA");
}; // namespace infini

View File

@ -0,0 +1,345 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
/**
 * Dequantizes one slice along the dequantized axis per thread block, with a
 * zero point: output = T(inputX - inputZeroPoint) * inputScale.
 *
 * blockIdx.x selects the slice; its elements sit at `base + c * stride` for
 * c in [0, dimsize). inputScale and inputZeroPoint are indexed by c.
 * The dimsize indices are divided into contiguous runs, one per thread: the
 * first `dimsize % BLOCK_DIM` threads take one more index than the others.
 */
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
    void blockDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
                                     T *output, const int dimsize,
                                     const int stride,
                                     const uint8_t *inputZeroPoint) {
    const int inner = blockIdx.x % stride;
    const int base = (blockIdx.x - inner) * dimsize + inner;

    const int remain = dimsize % BLOCK_DIM;
    const int step = (dimsize - remain) / BLOCK_DIM + 1;

    const int lane = threadIdx.x;
    const bool takesExtra = lane < remain;
    // Length and first index of this thread's run along the axis.
    const int count = takesExtra ? step : step - 1;
    const int first = takesExtra
                          ? lane * step
                          : remain * step + (lane - remain) * (step - 1);

    for (int k = 0; k < count; ++k) {
        const int channel = first + k;
        const int offset = base + channel * stride;
        output[offset] =
            static_cast<T>(inputX[offset] - inputZeroPoint[channel]) *
            inputScale[channel];
    }
}
/**
 * Dequantizes one slice along the dequantized axis per thread block, without
 * a zero point: output = T(inputX) * inputScale.
 *
 * blockIdx.x selects the slice; its elements sit at `base + c * stride` for
 * c in [0, dimsize). inputScale is indexed by c.
 * The dimsize indices are divided into contiguous runs, one per thread: the
 * first `dimsize % BLOCK_DIM` threads take one more index than the others.
 */
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
    void blockDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
                                     T *output, const int dimsize,
                                     const int stride) {
    const int inner = blockIdx.x % stride;
    const int base = (blockIdx.x - inner) * dimsize + inner;

    const int remain = dimsize % BLOCK_DIM;
    const int step = (dimsize - remain) / BLOCK_DIM + 1;

    const int lane = threadIdx.x;
    const bool takesExtra = lane < remain;
    // Length and first index of this thread's run along the axis.
    const int count = takesExtra ? step : step - 1;
    const int first = takesExtra
                          ? lane * step
                          : remain * step + (lane - remain) * (step - 1);

    for (int k = 0; k < count; ++k) {
        const int channel = first + k;
        const int offset = base + channel * stride;
        output[offset] = static_cast<T>(inputX[offset]) * inputScale[channel];
    }
}
/**
 * Dequantizes several slices per thread block, with a zero point:
 * output = T(inputX - inputZeroPoint) * inputScale.
 *
 * Each threadIdx.y row of the block handles the slice numbered
 * `blockIdx.x * blockDim.y + threadIdx.y`; rows whose slice number is not
 * below otherSize do nothing. Within a row, the BLOCK_DIM_x threads divide
 * the dimsize indices into contiguous runs, the first
 * `dimsize % BLOCK_DIM_x` threads taking one more index than the others.
 */
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void
warpDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
                           T *output, const int dimsize, const int otherSize,
                           const int stride, const uint8_t *inputZeroPoint) {
    const int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    if (otherIdx >= otherSize) {
        return;
    }
    const int inner = otherIdx % stride;
    const int base = inner + (otherIdx - inner) * dimsize;

    const int remain = dimsize % BLOCK_DIM_x;
    const int step = (dimsize - remain) / BLOCK_DIM_x + 1;

    const int lane = threadIdx.x;
    const bool takesExtra = lane < remain;
    // Length and first index of this thread's run along the axis.
    const int count = takesExtra ? step : step - 1;
    const int first = takesExtra
                          ? lane * step
                          : remain * step + (lane - remain) * (step - 1);

    for (int k = 0; k < count; ++k) {
        const int channel = first + k;
        const int offset = base + channel * stride;
        output[offset] =
            static_cast<T>(inputX[offset] - inputZeroPoint[channel]) *
            inputScale[channel];
    }
}
/**
 * Dequantizes several slices per thread block, without a zero point:
 * output = T(inputX) * inputScale.
 *
 * Each threadIdx.y row of the block handles the slice numbered
 * `blockIdx.x * blockDim.y + threadIdx.y`; rows whose slice number is not
 * below otherSize do nothing. Within a row, the BLOCK_DIM_x threads divide
 * the dimsize indices into contiguous runs, the first
 * `dimsize % BLOCK_DIM_x` threads taking one more index than the others.
 */
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void
warpDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
                           T *output, const int dimsize, const int otherSize,
                           const int stride) {
    const int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    if (otherIdx >= otherSize) {
        return;
    }
    const int inner = otherIdx % stride;
    const int base = inner + (otherIdx - inner) * dimsize;

    const int remain = dimsize % BLOCK_DIM_x;
    const int step = (dimsize - remain) / BLOCK_DIM_x + 1;

    const int lane = threadIdx.x;
    const bool takesExtra = lane < remain;
    // Length and first index of this thread's run along the axis.
    const int count = takesExtra ? step : step - 1;
    const int first = takesExtra
                          ? lane * step
                          : remain * step + (lane - remain) * (step - 1);

    for (int k = 0; k < count; ++k) {
        const int channel = first + k;
        const int offset = base + channel * stride;
        output[offset] = static_cast<T>(inputX[offset]) * inputScale[channel];
    }
}
namespace infini {
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
float *output, const int dimsize, const int stride,
const uint8_t *inputZeroPoint, const int size) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockDequantizeLinearKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
inputX, inputScale, output, dimsize, stride, inputZeroPoint);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 32, 32>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 16, 64>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 8, 128>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 4, 256>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
}
}
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
float *output, const int dimsize, const int stride,
const int size) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockDequantizeLinearKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
inputX, inputScale, output, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 32, 32><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 16, 64><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 8, 128><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<float, 4, 256><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
}
}
//-------------
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
half *output, const int dimsize, const int stride,
const uint8_t *inputZeroPoint, const int size) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockDequantizeLinearKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
inputX, inputScale, output, dimsize, stride, inputZeroPoint);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 32, 32>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 16, 64>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 8, 128>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 4, 256>
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
num_block, stride, inputZeroPoint);
}
}
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
half *output, const int dimsize, const int stride,
const int size) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockDequantizeLinearKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
inputX, inputScale, output, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 32, 32><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 16, 64><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 8, 128><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpDequantizeLinearKernel<half, 4, 256><<<grid_dim, block_dim>>>(
inputX, inputScale, output, dimsize, num_block, stride);
}
}
} // namespace infini

View File

@ -1,13 +1,15 @@
#include "operators/dequantize_linear.h" #include "operators/dequantizeLinear.h"
#include "utils/operator_utils.h" #include "utils/operator_utils.h"
namespace infini { namespace infini {
DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor input,
Tensor scale, Tensor zero_point, DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor inputX,
Tensor output, int axis) Tensor inputScale, Tensor output,
[[maybe_unused]] Tensor inputZeroPoint,
int axis)
: OperatorObj(OpType::DequantizeLinear, : OperatorObj(OpType::DequantizeLinear,
zero_point ? TensorVec{input, scale, zero_point} inputZeroPoint ? TensorVec{inputX, inputScale, inputZeroPoint}
: TensorVec{input, scale}, : TensorVec{inputX, inputScale},
{output}), {output}),
axis(axis) { axis(axis) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
@ -15,13 +17,14 @@ DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor input,
optional<vector<Shape>> optional<vector<Shape>>
DequantizeLinearObj::inferShape(const TensorVec &inputs) { DequantizeLinearObj::inferShape(const TensorVec &inputs) {
return {{inputs[0]->getDims()}}; return {{inputs[0]->getDims()}}; // x.shape = output.shape = inputs[0].shape
} }
vector<DataType> vector<DataType>
DequantizeLinearObj::inferDataType(const TensorVec &inputs) const { DequantizeLinearObj::inferDataType(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3); IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
return {inputs[1]->getDType()};
return {
inputs[1]->getDType()}; // scale.dtype = output.dtype = inputs[1].dtype
} }
std::string DequantizeLinearObj::toString() const { std::string DequantizeLinearObj::toString() const {
@ -29,12 +32,11 @@ std::string DequantizeLinearObj::toString() const {
os << "DequantizeLinear[" << getGuid() << "]"; os << "DequantizeLinear[" << getGuid() << "]";
os << "("; os << "(";
os << vecToString(inputs[0]->getDims()) << ","; os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ","; os << "inputX=" << inputs[0]->getGuid() << ",";
os << "scale=" << inputs[1]->getGuid() << ","; os << "inputScale=" << inputs[1]->getGuid() << ",";
// os << "inputZeroPoint=" << inputs[2]->getGuid() << ",";
os << "axis=" << axis << ","; os << "axis=" << axis << ",";
os << "output="; os << "output=" << outputs[0]->getGuid() << ")";
for (auto output : outputs)
os << output->getGuid() << ",";
return os.str(); return os.str();
} }
@ -45,7 +47,7 @@ vector<int> DequantizeLinearObj::getWorkloadVector() const {
} }
vector<int> DequantizeLinearObj::getOpAttrVector() const { vector<int> DequantizeLinearObj::getOpAttrVector() const {
return {type.underlying()}; return {type.underlying(), axis};
} }
} // namespace infini } // namespace infini

View File

@ -0,0 +1,219 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/dequantizeLinear.h"
#include "test.h"
namespace infini {
void test_dequantizeLinearFp32(
const Shape &inputXShape, const vector<uint8_t> &inputXData,
const Shape &inputScaleShape, const vector<float> &inputScaleData, int axis,
const vector<float> &ExpectData,
const std::optional<Shape> &zeroPointShape = std::nullopt,
const std::optional<std::vector<uint8_t>> &inputZeroPointData =
std::nullopt) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
if (zeroPointShape.has_value() && inputZeroPointData.has_value()) {
Shape inputZeroPointShape = *zeroPointShape;
auto inputZeroPoint =
gCpu->addTensor(inputZeroPointShape, DataType::UInt8);
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float32);
gCpu->dataMalloc();
inputZeroPoint->copyin(*inputZeroPointData); //
inputX->copyin(inputXData);
inputScale->copyin(inputScaleData); //
// inputX->printData();
// inputZeroPoint->printData();
// inputScale->printData();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputZeroPointGpu = gCuda->cloneTensor(inputZeroPoint);
auto inputXGpu = gCuda->cloneTensor(inputX);
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
auto op = gCuda->addOp<DequantizeLinearObj>(
inputXGpu, inputScaleGpu, nullptr, inputZeroPointGpu,
axis); // DequantizeLinearObj
gCuda->dataMalloc();
inputZeroPointGpu->copyin(*inputZeroPointData);
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
inputXGpu->copyin(inputXData);
inputScaleGpu->copyin(inputScaleData);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
} else {
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float32);
gCpu->dataMalloc();
inputX->copyin(inputXData);
inputScale->copyin(inputScaleData); //
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputXGpu = gCuda->cloneTensor(inputX);
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
auto op = gCuda->addOp<DequantizeLinearObj>(
inputXGpu, inputScaleGpu, nullptr, nullptr,
axis); // DequantizeLinearObj
gCuda->dataMalloc();
inputXGpu->copyin(inputXData);
inputScaleGpu->copyin(inputScaleData);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
}
void test_dequantizeLinearFp16(
const Shape &inputXShape, const vector<uint8_t> &inputXData,
const Shape &inputScaleShape,
const std::function<void(void *, size_t, DataType)> &generator, int axis,
const vector<float> &ExpectData,
const std::optional<Shape> &zeroPointShape = std::nullopt,
const std::optional<std::vector<uint8_t>> &inputZeroPointData =
std::nullopt) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
if (zeroPointShape.has_value() && inputZeroPointData.has_value()) {
Shape inputZeroPointShape = *zeroPointShape;
auto inputZeroPoint =
gCpu->addTensor(inputZeroPointShape, DataType::UInt8);
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float16);
gCpu->dataMalloc();
inputZeroPoint->copyin(*inputZeroPointData); //
// inputZeroPoint->printData();
inputX->copyin(inputXData);
inputScale->setData(generator);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputZeroPointGpu = gCuda->cloneTensor(inputZeroPoint);
auto inputXGpu = gCuda->cloneTensor(inputX);
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
auto op = gCuda->addOp<DequantizeLinearObj>(
inputXGpu, inputScaleGpu, nullptr, inputZeroPointGpu,
axis); // DequantizeLinearObj
gCuda->dataMalloc();
inputZeroPointGpu->copyin(*inputZeroPointData);
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
inputXGpu->copyin(inputXData);
inputScaleGpu->setData(generator);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
} else {
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float16);
gCpu->dataMalloc();
inputX->copyin(inputXData);
inputScale->setData(generator);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputXGpu = gCuda->cloneTensor(inputX);
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
auto op = gCuda->addOp<DequantizeLinearObj>(
inputXGpu, inputScaleGpu, nullptr, nullptr,
axis); // DequantizeLinearObj
gCuda->dataMalloc();
inputXGpu->copyin(inputXData);
inputScaleGpu->setData(generator);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
}
// float32 scale: a 2x3x2x3 uint8 input holding 0..35 is dequantized along
// axis 1 with scales {0.3, 0.2, 0.5}, once with zero points {1, 2, 3} and
// once without. Expected values are marked "python output" in the original.
TEST(CUDA_DequantizeLinearFp32, run) {
    const Shape xShape{2, 3, 2, 3};
    const vector<uint8_t> xData{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                                24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
    const Shape scaleShape{3};
    const vector<float> scaleData{0.3, 0.2, 0.5};
    const int axis = 1;

    // With zero point.
    const vector<float> expectedWithZeroPoint{
        -0.3000000, 0.0000000,  0.3000000,  0.6000000,  0.9000000,  1.2000000,
        0.8000000,  1.0000000,  1.2000000,  1.4000000,  1.6000000,  1.8000001,
        4.5000000,  5.0000000,  5.5000000,  6.0000000,  6.5000000,  7.0000000,
        5.1000004,  5.4000001,  5.7000003,  6.0000000,  6.3000002,  6.6000004,
        4.4000001,  4.5999999,  4.8000002,  5.0000000,  5.2000003,  5.4000001,
        13.5000000, 14.0000000, 14.5000000, 15.0000000, 15.5000000, 16.0000000};
    test_dequantizeLinearFp32(xShape, xData, scaleShape, scaleData, axis,
                              expectedWithZeroPoint, Shape{3},
                              vector<uint8_t>{1, 2, 3});

    // Without zero point.
    const vector<float> expectedNoZeroPoint{
        0.0000000,  0.3000000,  0.6000000,  0.9000000,  1.2000000,  1.5000000,
        1.2000000,  1.4000000,  1.6000000,  1.8000001,  2.0000000,  2.2000000,
        6.0000000,  6.5000000,  7.0000000,  7.5000000,  8.0000000,  8.5000000,
        5.4000001,  5.7000003,  6.0000000,  6.3000002,  6.6000004,  6.9000001,
        4.8000002,  5.0000000,  5.2000003,  5.4000001,  5.5999999,  5.8000002,
        15.0000000, 15.5000000, 16.0000000, 16.5000000, 17.0000000, 17.5000000};
    test_dequantizeLinearFp32(xShape, xData, scaleShape, scaleData, axis,
                              expectedNoZeroPoint);
} // python output
// float16 scale: a 2x3x2x3 uint8 input holding 0..35 is dequantized along
// axis 1 with a scale tensor filled by ValGenerator<2>, once with zero points
// {1, 2, 3} and once without. Expected values are marked "python output" in
// the original.
TEST(CUDA_DequantizeLinearFp16, run) {
    const Shape xShape{2, 3, 2, 3};
    const vector<uint8_t> xData{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                                24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
    const Shape scaleShape{3};
    const int axis = 1;

    // With zero point.
    const vector<float> expectedWithZeroPoint{
        -2., 0.,  2.,  4.,  6.,  8.,  8.,  10., 12., 14., 16., 18.,
        18., 20., 22., 24., 26., 28., 34., 36., 38., 40., 42., 44.,
        44., 46., 48., 50., 52., 54., 54., 56., 58., 60., 62., 64.};
    test_dequantizeLinearFp16(xShape, xData, scaleShape, ValGenerator<2>(),
                              axis, expectedWithZeroPoint, Shape{3},
                              vector<uint8_t>{1, 2, 3});

    // Without zero point.
    const vector<float> expectedNoZeroPoint{
        0.,  2.,  4.,  6.,  8.,  10., 12., 14., 16., 18., 20., 22.,
        24., 26., 28., 30., 32., 34., 36., 38., 40., 42., 44., 46.,
        48., 50., 52., 54., 56., 58., 60., 62., 64., 66., 68., 70.};
    test_dequantizeLinearFp16(xShape, xData, scaleShape, ValGenerator<2>(),
                              axis, expectedNoZeroPoint);
} // python output
} // namespace infini