forked from jiuyuan/InfiniTensor
add half and float dequantizeLinear
This commit is contained in:
parent
03ed8c4de7
commit
9c82936386
|
@ -85,6 +85,8 @@ class GraphHandlerObj {
|
||||||
Tensor cast(Tensor input, Tensor output, int to);
|
Tensor cast(Tensor input, Tensor output, int to);
|
||||||
Tensor expand(Tensor input, Tensor output, Shape dims);
|
Tensor expand(Tensor input, Tensor output, Shape dims);
|
||||||
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
||||||
|
Tensor dequantizeLinear(Tensor inputX, Tensor inputScale, Tensor output,
|
||||||
|
Tensor inputZeroPoint, int axis);
|
||||||
std::vector<int> getDims(Tensor x) { return x->getDims(); }
|
std::vector<int> getDims(Tensor x) { return x->getDims(); }
|
||||||
|
|
||||||
Tensor allReduceSum(Tensor input, Tensor output);
|
Tensor allReduceSum(Tensor input, Tensor output);
|
||||||
|
@ -102,9 +104,6 @@ class GraphHandlerObj {
|
||||||
TensorVec dynamicQuantizeLinear(Tensor input,
|
TensorVec dynamicQuantizeLinear(Tensor input,
|
||||||
std::optional<TensorVec> outputs);
|
std::optional<TensorVec> outputs);
|
||||||
|
|
||||||
Tensor dequantizeLinear(Tensor input, Tensor scale, Tensor zero_point,
|
|
||||||
Tensor output, int axis);
|
|
||||||
|
|
||||||
//------ modifiers
|
//------ modifiers
|
||||||
|
|
||||||
inline bool topo_sort() { return g->topo_sort(); }
|
inline bool topo_sort() { return g->topo_sort(); }
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
#pragma once
|
||||||
|
#include "operators/unary.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
|
||||||
|
float *output, const int dimsize, const int stride,
|
||||||
|
const uint8_t *inputZeroPoint, const int size);
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
|
||||||
|
float *output, const int dimsize, const int stride,
|
||||||
|
const int size);
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
|
||||||
|
half *output, const int dimsize, const int stride,
|
||||||
|
const uint8_t *inputZeroPoint, const int size);
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
|
||||||
|
half *output, const int dimsize, const int stride,
|
||||||
|
const int size);
|
||||||
|
}; // namespace infini
|
|
@ -3,9 +3,8 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
/**
|
/**
|
||||||
* @brief The linear dequantization operator.
|
* @brief y = (x - x_zero_point) *x_scale
|
||||||
* It consumes a quantized tensor, a scale, and a zero point to compute
|
*
|
||||||
* the full precision tensor.
|
|
||||||
*/
|
*/
|
||||||
class DequantizeLinearObj : public OperatorObj {
|
class DequantizeLinearObj : public OperatorObj {
|
||||||
int axis;
|
int axis;
|
||||||
|
@ -15,26 +14,30 @@ class DequantizeLinearObj : public OperatorObj {
|
||||||
* @brief Construct a new DequantizeLinear object.
|
* @brief Construct a new DequantizeLinear object.
|
||||||
*
|
*
|
||||||
* @param graph The computation graph that this operator belongs to.
|
* @param graph The computation graph that this operator belongs to.
|
||||||
* @param input The input tensor.
|
* @param inputX The input tensor X.
|
||||||
* @param scale Scale for input.
|
* @param inputScale The input tensor x_scale.
|
||||||
* @param zero_point Zero point for input.
|
* @param output The output tensor.
|
||||||
* @param outputs The output tensors.
|
* @param inputZeroPoint The z_zero_point.
|
||||||
* @param axis The axis of the dequantizing dimension of the input tensor.
|
|
||||||
*/
|
*/
|
||||||
DequantizeLinearObj(GraphObj *graph, Tensor input, Tensor scale,
|
DequantizeLinearObj(GraphObj *graph, Tensor inputX, Tensor inputScale,
|
||||||
Tensor zero_pointr, Tensor output, int axis);
|
Tensor output, Tensor inputZeroPoint = nullptr,
|
||||||
|
int axis = 1);
|
||||||
OP_CLONE(DequantizeLinearObj);
|
OP_CLONE(DequantizeLinearObj);
|
||||||
|
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
std::string toString() const override;
|
std::string toString() const override;
|
||||||
|
|
||||||
|
Tensor getZeroPoint() const {
|
||||||
|
return inputs.size() > 2 ? inputs[2] : nullptr;
|
||||||
|
}
|
||||||
int numInputs() const override { return inputs.size(); }
|
int numInputs() const override { return inputs.size(); }
|
||||||
int numOutputs() const override { return 1; }
|
int numOutputs() const override { return 1; }
|
||||||
|
int getAxis() const { return axis; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
|
|
||||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -866,19 +866,23 @@ class OnnxStub:
|
||||||
):
|
):
|
||||||
tensors[name] = tensor
|
tensors[name] = tensor
|
||||||
elif node.op_type == "DequantizeLinear":
|
elif node.op_type == "DequantizeLinear":
|
||||||
attributes = _parse_attribute(
|
(inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1])
|
||||||
node,
|
inputZeroPoint = (
|
||||||
{
|
None if len(node.input) < 3 else tensors[node.input[2]]
|
||||||
"axis": 1,
|
)
|
||||||
},
|
output = tensors.get(node.output[0])
|
||||||
|
axis = next(
|
||||||
|
(attr.i for attr in node.attribute if attr.name == "axis"),
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
axis = attributes["axis"]
|
|
||||||
tensors[node.output[0]] = self.handler.dequantizeLinear(
|
tensors[node.output[0]] = self.handler.dequantizeLinear(
|
||||||
tensor[node.input[0]],
|
inputX,
|
||||||
tensor[node.input[1]],
|
inputScale,
|
||||||
tensor[node.input[2]] if len(node.input) > 2 else None,
|
output,
|
||||||
|
inputZeroPoint,
|
||||||
axis,
|
axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
new_node_name.append(node.name)
|
new_node_name.append(node.name)
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#include "operators/broadcast.h"
|
#include "operators/broadcast.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/dequantize_linear.h"
|
#include "operators/dequantizeLinear.h"
|
||||||
#include "operators/dynamic_quantize_linear.h"
|
#include "operators/dynamic_quantize_linear.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/expand.h"
|
#include "operators/expand.h"
|
||||||
|
@ -521,18 +521,19 @@ GraphHandlerObj::dynamicQuantizeLinear(Tensor input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::dequantizeLinear(Tensor input, Tensor scale,
|
Tensor GraphHandlerObj::dequantizeLinear(Tensor inputX, Tensor inputScale,
|
||||||
Tensor zero_point, Tensor output,
|
Tensor output, Tensor inputZeroPoint,
|
||||||
int axis) {
|
int axis) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<DequantizeLinearObj>(
|
g->addOpWithOutputs<DequantizeLinearObj>(
|
||||||
std::move(input), std::move(scale), std::move(zero_point), output,
|
std::move(inputX), std::move(inputScale), output,
|
||||||
axis);
|
std::move(inputZeroPoint), axis);
|
||||||
return output;
|
return output;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g
|
||||||
->addOp<DequantizeLinearObj>(std::move(input), std::move(scale),
|
->addOp<DequantizeLinearObj>(std::move(inputX),
|
||||||
std::move(zero_point), output, axis)
|
std::move(inputScale), output,
|
||||||
|
std::move(inputZeroPoint), axis)
|
||||||
->getOutput();
|
->getOutput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -517,6 +517,7 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("expand", &Handler::expand, policy::move)
|
.def("expand", &Handler::expand, policy::move)
|
||||||
.def("erf", &Handler::erf, policy::move)
|
.def("erf", &Handler::erf, policy::move)
|
||||||
.def("where", &Handler::where, policy::move)
|
.def("where", &Handler::where, policy::move)
|
||||||
|
.def("dequantizeLinear", &Handler::dequantizeLinear, policy::move)
|
||||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||||
.def("optimize", &Handler::optimize, policy::automatic)
|
.def("optimize", &Handler::optimize, policy::automatic)
|
||||||
.def("operators", &Handler::operators, policy::move)
|
.def("operators", &Handler::operators, policy::move)
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
#include "operators/dequantizeLinear.h"
|
||||||
|
#include "cuda/cuda_dequantizeLinear.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class DequantizeLinearCuda : public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<DequantizeLinearObj>(_op);
|
||||||
|
|
||||||
|
void *const inputX = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const inputScale = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
const int axis = op->getAxis();
|
||||||
|
const int stride = op->getInputs(0)->getStride().at(axis);
|
||||||
|
|
||||||
|
auto dims = op->getInputs(0)->getDims();
|
||||||
|
int dimsize = dims[op->getAxis()];
|
||||||
|
int size = op->getOutput()->size();
|
||||||
|
|
||||||
|
if (op->getInputs(1)->getDType() == DataType::Float32) {
|
||||||
|
|
||||||
|
if (op->numInputs() == 3) {
|
||||||
|
void *const inputZeroPoint =
|
||||||
|
(op->getInputs(2)->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
DequantizeLinearKernel((uint8_t *)inputX, (float *)inputScale,
|
||||||
|
(float *)output, dimsize, stride,
|
||||||
|
(uint8_t *)inputZeroPoint, size);
|
||||||
|
} else {
|
||||||
|
DequantizeLinearKernel((uint8_t *)inputX, (float *)inputScale,
|
||||||
|
(float *)output, dimsize, stride, size);
|
||||||
|
}
|
||||||
|
} else if (op->getInputs(1)->getDType() == DataType::Float16) {
|
||||||
|
if (op->numInputs() == 3) {
|
||||||
|
void *const inputZeroPoint =
|
||||||
|
(op->getInputs(2)->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
DequantizeLinearKernel((uint8_t *)inputX, (half *)inputScale,
|
||||||
|
(half *)output, dimsize, stride,
|
||||||
|
(uint8_t *)inputZeroPoint, size);
|
||||||
|
} else {
|
||||||
|
DequantizeLinearKernel((uint8_t *)inputX, (half *)inputScale,
|
||||||
|
(half *)output, dimsize, stride, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::DequantizeLinear, DequantizeLinearCuda,
|
||||||
|
"DequantizeLinear_CUDA");
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,345 @@
|
||||||
|
#include "cuda/cuda_common.h"
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_DIM>
|
||||||
|
__launch_bounds__(BLOCK_DIM) __global__
|
||||||
|
void blockDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
|
||||||
|
T *output, const int dimsize,
|
||||||
|
const int stride,
|
||||||
|
const uint8_t *inputZeroPoint) {
|
||||||
|
// len(scale) = len(bias) = dimsize
|
||||||
|
int tmp = blockIdx.x % stride;
|
||||||
|
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||||
|
int remain = dimsize % BLOCK_DIM;
|
||||||
|
int step = (dimsize - remain) / BLOCK_DIM + 1;
|
||||||
|
if (threadIdx.x < remain) {
|
||||||
|
for (int ind = 0; ind < step; ind++) {
|
||||||
|
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (threadIdx.x * step + ind) * stride] -
|
||||||
|
inputZeroPoint[threadIdx.x * step + ind]) *
|
||||||
|
inputScale[threadIdx.x * step + ind];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int ind = 0; ind < step - 1; ind++) {
|
||||||
|
output[tid +
|
||||||
|
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] -
|
||||||
|
inputZeroPoint[remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind]) *
|
||||||
|
inputScale[remain * step + (threadIdx.x - remain) * (step - 1) +
|
||||||
|
ind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T, int BLOCK_DIM>
|
||||||
|
__launch_bounds__(BLOCK_DIM) __global__
|
||||||
|
void blockDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
|
||||||
|
T *output, const int dimsize,
|
||||||
|
const int stride) {
|
||||||
|
// len(scale) = len(bias) = dimsize
|
||||||
|
int tmp = blockIdx.x % stride;
|
||||||
|
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||||
|
int remain = dimsize % BLOCK_DIM;
|
||||||
|
int step = (dimsize - remain) / BLOCK_DIM + 1;
|
||||||
|
if (threadIdx.x < remain) {
|
||||||
|
for (int ind = 0; ind < step; ind++) {
|
||||||
|
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (threadIdx.x * step + ind) * stride]) *
|
||||||
|
inputScale[threadIdx.x * step + ind];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int ind = 0; ind < step - 1; ind++) {
|
||||||
|
output[tid +
|
||||||
|
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride]) *
|
||||||
|
inputScale[remain * step + (threadIdx.x - remain) * (step - 1) +
|
||||||
|
ind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||||
|
__global__ void
|
||||||
|
warpDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
|
||||||
|
T *output, const int dimsize, const int otherSize,
|
||||||
|
const int stride, const uint8_t *inputZeroPoint) {
|
||||||
|
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||||
|
int remain = dimsize % BLOCK_DIM_x;
|
||||||
|
int step = (dimsize - remain) / BLOCK_DIM_x + 1;
|
||||||
|
if (otherIdx < otherSize) {
|
||||||
|
if (threadIdx.x < remain) {
|
||||||
|
for (int ind = 0; ind < step; ind++) {
|
||||||
|
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (threadIdx.x * step + ind) * stride] -
|
||||||
|
inputZeroPoint[threadIdx.x * step + ind]) *
|
||||||
|
inputScale[threadIdx.x * step + ind];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int ind = 0; ind < step - 1; ind++) {
|
||||||
|
output[tid + (remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid +
|
||||||
|
(remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] -
|
||||||
|
inputZeroPoint[remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) +
|
||||||
|
ind]) *
|
||||||
|
inputScale[remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||||
|
__global__ void
|
||||||
|
warpDequantizeLinearKernel(const uint8_t *inputX, const T *inputScale,
|
||||||
|
T *output, const int dimsize, const int otherSize,
|
||||||
|
const int stride) {
|
||||||
|
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||||
|
int remain = dimsize % BLOCK_DIM_x;
|
||||||
|
int step = (dimsize - remain) / BLOCK_DIM_x + 1;
|
||||||
|
if (otherIdx < otherSize) {
|
||||||
|
if (threadIdx.x < remain) {
|
||||||
|
for (int ind = 0; ind < step; ind++) {
|
||||||
|
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid + (threadIdx.x * step + ind) * stride]) *
|
||||||
|
inputScale[threadIdx.x * step + ind];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int ind = 0; ind < step - 1; ind++) {
|
||||||
|
output[tid + (remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride] =
|
||||||
|
static_cast<T>(
|
||||||
|
inputX[tid +
|
||||||
|
(remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||||
|
stride]) *
|
||||||
|
inputScale[remain * step +
|
||||||
|
(threadIdx.x - remain) * (step - 1) + ind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
namespace infini {
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
|
||||||
|
float *output, const int dimsize, const int stride,
|
||||||
|
const uint8_t *inputZeroPoint, const int size) {
|
||||||
|
|
||||||
|
int num_block = size / dimsize;
|
||||||
|
if (dimsize > 1024) {
|
||||||
|
int BLOCK_DIM = 1024;
|
||||||
|
|
||||||
|
blockDequantizeLinearKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
|
||||||
|
inputX, inputScale, output, dimsize, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 31) {
|
||||||
|
int BLOCK_DIM_x = 32;
|
||||||
|
int BLOCK_DIM_y = 32;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 32, 32>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 15) {
|
||||||
|
int BLOCK_DIM_x = 16;
|
||||||
|
int BLOCK_DIM_y = 64;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 16, 64>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 7) {
|
||||||
|
int BLOCK_DIM_x = 8;
|
||||||
|
int BLOCK_DIM_y = 128;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 8, 128>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else {
|
||||||
|
int BLOCK_DIM_x = 4;
|
||||||
|
int BLOCK_DIM_y = 256;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 4, 256>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const float *inputScale,
|
||||||
|
float *output, const int dimsize, const int stride,
|
||||||
|
const int size) {
|
||||||
|
int num_block = size / dimsize;
|
||||||
|
if (dimsize > 1024) {
|
||||||
|
int BLOCK_DIM = 1024;
|
||||||
|
|
||||||
|
blockDequantizeLinearKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
|
||||||
|
inputX, inputScale, output, dimsize, stride);
|
||||||
|
} else if (dimsize > 31) {
|
||||||
|
int BLOCK_DIM_x = 32;
|
||||||
|
int BLOCK_DIM_y = 32;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 32, 32><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else if (dimsize > 15) {
|
||||||
|
int BLOCK_DIM_x = 16;
|
||||||
|
int BLOCK_DIM_y = 64;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 16, 64><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else if (dimsize > 7) {
|
||||||
|
int BLOCK_DIM_x = 8;
|
||||||
|
int BLOCK_DIM_y = 128;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 8, 128><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else {
|
||||||
|
int BLOCK_DIM_x = 4;
|
||||||
|
int BLOCK_DIM_y = 256;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<float, 4, 256><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//-------------
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
|
||||||
|
half *output, const int dimsize, const int stride,
|
||||||
|
const uint8_t *inputZeroPoint, const int size) {
|
||||||
|
int num_block = size / dimsize;
|
||||||
|
if (dimsize > 1024) {
|
||||||
|
int BLOCK_DIM = 1024;
|
||||||
|
|
||||||
|
blockDequantizeLinearKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
|
||||||
|
inputX, inputScale, output, dimsize, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 31) {
|
||||||
|
int BLOCK_DIM_x = 32;
|
||||||
|
int BLOCK_DIM_y = 32;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 32, 32>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 15) {
|
||||||
|
int BLOCK_DIM_x = 16;
|
||||||
|
int BLOCK_DIM_y = 64;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 16, 64>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else if (dimsize > 7) {
|
||||||
|
int BLOCK_DIM_x = 8;
|
||||||
|
int BLOCK_DIM_y = 128;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 8, 128>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
} else {
|
||||||
|
int BLOCK_DIM_x = 4;
|
||||||
|
int BLOCK_DIM_y = 256;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 4, 256>
|
||||||
|
<<<grid_dim, block_dim>>>(inputX, inputScale, output, dimsize,
|
||||||
|
num_block, stride, inputZeroPoint);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DequantizeLinearKernel(const uint8_t *inputX, const half *inputScale,
|
||||||
|
half *output, const int dimsize, const int stride,
|
||||||
|
const int size) {
|
||||||
|
int num_block = size / dimsize;
|
||||||
|
if (dimsize > 1024) {
|
||||||
|
int BLOCK_DIM = 1024;
|
||||||
|
|
||||||
|
blockDequantizeLinearKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
|
||||||
|
inputX, inputScale, output, dimsize, stride);
|
||||||
|
} else if (dimsize > 31) {
|
||||||
|
int BLOCK_DIM_x = 32;
|
||||||
|
int BLOCK_DIM_y = 32;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 32, 32><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else if (dimsize > 15) {
|
||||||
|
int BLOCK_DIM_x = 16;
|
||||||
|
int BLOCK_DIM_y = 64;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 16, 64><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else if (dimsize > 7) {
|
||||||
|
int BLOCK_DIM_x = 8;
|
||||||
|
int BLOCK_DIM_y = 128;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 8, 128><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
} else {
|
||||||
|
int BLOCK_DIM_x = 4;
|
||||||
|
int BLOCK_DIM_y = 256;
|
||||||
|
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
warpDequantizeLinearKernel<half, 4, 256><<<grid_dim, block_dim>>>(
|
||||||
|
inputX, inputScale, output, dimsize, num_block, stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -1,13 +1,15 @@
|
||||||
#include "operators/dequantize_linear.h"
|
#include "operators/dequantizeLinear.h"
|
||||||
#include "utils/operator_utils.h"
|
#include "utils/operator_utils.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor input,
|
|
||||||
Tensor scale, Tensor zero_point,
|
DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor inputX,
|
||||||
Tensor output, int axis)
|
Tensor inputScale, Tensor output,
|
||||||
|
[[maybe_unused]] Tensor inputZeroPoint,
|
||||||
|
int axis)
|
||||||
: OperatorObj(OpType::DequantizeLinear,
|
: OperatorObj(OpType::DequantizeLinear,
|
||||||
zero_point ? TensorVec{input, scale, zero_point}
|
inputZeroPoint ? TensorVec{inputX, inputScale, inputZeroPoint}
|
||||||
: TensorVec{input, scale},
|
: TensorVec{inputX, inputScale},
|
||||||
{output}),
|
{output}),
|
||||||
axis(axis) {
|
axis(axis) {
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
|
@ -15,13 +17,14 @@ DequantizeLinearObj::DequantizeLinearObj(GraphObj *graph, Tensor input,
|
||||||
|
|
||||||
optional<vector<Shape>>
|
optional<vector<Shape>>
|
||||||
DequantizeLinearObj::inferShape(const TensorVec &inputs) {
|
DequantizeLinearObj::inferShape(const TensorVec &inputs) {
|
||||||
return {{inputs[0]->getDims()}};
|
return {{inputs[0]->getDims()}}; // x.shape = output.shape = inputs[0].shape
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<DataType>
|
vector<DataType>
|
||||||
DequantizeLinearObj::inferDataType(const TensorVec &inputs) const {
|
DequantizeLinearObj::inferDataType(const TensorVec &inputs) const {
|
||||||
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
|
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
|
||||||
return {inputs[1]->getDType()};
|
|
||||||
|
return {
|
||||||
|
inputs[1]->getDType()}; // scale.dtype = output.dtype = inputs[1].dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string DequantizeLinearObj::toString() const {
|
std::string DequantizeLinearObj::toString() const {
|
||||||
|
@ -29,12 +32,11 @@ std::string DequantizeLinearObj::toString() const {
|
||||||
os << "DequantizeLinear[" << getGuid() << "]";
|
os << "DequantizeLinear[" << getGuid() << "]";
|
||||||
os << "(";
|
os << "(";
|
||||||
os << vecToString(inputs[0]->getDims()) << ",";
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
os << "input=" << inputs[0]->getGuid() << ",";
|
os << "inputX=" << inputs[0]->getGuid() << ",";
|
||||||
os << "scale=" << inputs[1]->getGuid() << ",";
|
os << "inputScale=" << inputs[1]->getGuid() << ",";
|
||||||
|
// os << "inputZeroPoint=" << inputs[2]->getGuid() << ",";
|
||||||
os << "axis=" << axis << ",";
|
os << "axis=" << axis << ",";
|
||||||
os << "output=";
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
for (auto output : outputs)
|
|
||||||
os << output->getGuid() << ",";
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,7 +47,7 @@ vector<int> DequantizeLinearObj::getWorkloadVector() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> DequantizeLinearObj::getOpAttrVector() const {
|
vector<int> DequantizeLinearObj::getOpAttrVector() const {
|
||||||
return {type.underlying()};
|
return {type.underlying(), axis};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
|
@ -0,0 +1,219 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/dequantizeLinear.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void test_dequantizeLinearFp32(
|
||||||
|
const Shape &inputXShape, const vector<uint8_t> &inputXData,
|
||||||
|
const Shape &inputScaleShape, const vector<float> &inputScaleData, int axis,
|
||||||
|
const vector<float> &ExpectData,
|
||||||
|
const std::optional<Shape> &zeroPointShape = std::nullopt,
|
||||||
|
const std::optional<std::vector<uint8_t>> &inputZeroPointData =
|
||||||
|
std::nullopt) {
|
||||||
|
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
|
||||||
|
if (zeroPointShape.has_value() && inputZeroPointData.has_value()) {
|
||||||
|
Shape inputZeroPointShape = *zeroPointShape;
|
||||||
|
|
||||||
|
auto inputZeroPoint =
|
||||||
|
gCpu->addTensor(inputZeroPointShape, DataType::UInt8);
|
||||||
|
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
|
||||||
|
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
inputZeroPoint->copyin(*inputZeroPointData); //
|
||||||
|
inputX->copyin(inputXData);
|
||||||
|
inputScale->copyin(inputScaleData); //
|
||||||
|
|
||||||
|
// inputX->printData();
|
||||||
|
// inputZeroPoint->printData();
|
||||||
|
// inputScale->printData();
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto inputZeroPointGpu = gCuda->cloneTensor(inputZeroPoint);
|
||||||
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
||||||
|
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
|
||||||
|
|
||||||
|
auto op = gCuda->addOp<DequantizeLinearObj>(
|
||||||
|
inputXGpu, inputScaleGpu, nullptr, inputZeroPointGpu,
|
||||||
|
axis); // DequantizeLinearObj
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
inputZeroPointGpu->copyin(*inputZeroPointData);
|
||||||
|
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
|
||||||
|
inputXGpu->copyin(inputXData);
|
||||||
|
inputScaleGpu->copyin(inputScaleData);
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto oCpu =
|
||||||
|
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||||
|
oCpu->printData(); //->printData
|
||||||
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||||
|
} else {
|
||||||
|
|
||||||
|
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
|
||||||
|
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
|
||||||
|
inputX->copyin(inputXData);
|
||||||
|
inputScale->copyin(inputScaleData); //
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
||||||
|
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
|
||||||
|
auto op = gCuda->addOp<DequantizeLinearObj>(
|
||||||
|
inputXGpu, inputScaleGpu, nullptr, nullptr,
|
||||||
|
axis); // DequantizeLinearObj
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
|
||||||
|
inputXGpu->copyin(inputXData);
|
||||||
|
inputScaleGpu->copyin(inputScaleData);
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto oCpu =
|
||||||
|
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||||
|
oCpu->printData(); //->printData
|
||||||
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void test_dequantizeLinearFp16(
|
||||||
|
const Shape &inputXShape, const vector<uint8_t> &inputXData,
|
||||||
|
const Shape &inputScaleShape,
|
||||||
|
const std::function<void(void *, size_t, DataType)> &generator, int axis,
|
||||||
|
const vector<float> &ExpectData,
|
||||||
|
const std::optional<Shape> &zeroPointShape = std::nullopt,
|
||||||
|
const std::optional<std::vector<uint8_t>> &inputZeroPointData =
|
||||||
|
std::nullopt) {
|
||||||
|
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
|
|
||||||
|
if (zeroPointShape.has_value() && inputZeroPointData.has_value()) {
|
||||||
|
Shape inputZeroPointShape = *zeroPointShape;
|
||||||
|
|
||||||
|
auto inputZeroPoint =
|
||||||
|
gCpu->addTensor(inputZeroPointShape, DataType::UInt8);
|
||||||
|
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
|
||||||
|
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float16);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
inputZeroPoint->copyin(*inputZeroPointData); //
|
||||||
|
// inputZeroPoint->printData();
|
||||||
|
inputX->copyin(inputXData);
|
||||||
|
inputScale->setData(generator);
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto inputZeroPointGpu = gCuda->cloneTensor(inputZeroPoint);
|
||||||
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
||||||
|
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
|
||||||
|
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
|
||||||
|
auto op = gCuda->addOp<DequantizeLinearObj>(
|
||||||
|
inputXGpu, inputScaleGpu, nullptr, inputZeroPointGpu,
|
||||||
|
axis); // DequantizeLinearObj
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
inputZeroPointGpu->copyin(*inputZeroPointData);
|
||||||
|
// gCpu->cloneTensor(inputZeroPointGpu)->printData();
|
||||||
|
inputXGpu->copyin(inputXData);
|
||||||
|
inputScaleGpu->setData(generator);
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto oCpu =
|
||||||
|
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||||
|
oCpu->printData(); //->printData
|
||||||
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||||
|
} else {
|
||||||
|
|
||||||
|
auto inputX = gCpu->addTensor(inputXShape, DataType::UInt8);
|
||||||
|
auto inputScale = gCpu->addTensor(inputScaleShape, DataType::Float16);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
|
||||||
|
inputX->copyin(inputXData);
|
||||||
|
inputScale->setData(generator);
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
||||||
|
auto inputScaleGpu = gCuda->cloneTensor(inputScale);
|
||||||
|
auto op = gCuda->addOp<DequantizeLinearObj>(
|
||||||
|
inputXGpu, inputScaleGpu, nullptr, nullptr,
|
||||||
|
axis); // DequantizeLinearObj
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
|
||||||
|
inputXGpu->copyin(inputXData);
|
||||||
|
inputScaleGpu->setData(generator);
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
|
auto oCpu =
|
||||||
|
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||||
|
oCpu->printData(); //->printData
|
||||||
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_DequantizeLinearFp32, run) {
|
||||||
|
|
||||||
|
test_dequantizeLinearFp32(
|
||||||
|
Shape{2, 3, 2, 3},
|
||||||
|
vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
|
||||||
|
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1,
|
||||||
|
vector<float>{-0.3000000, 0.0000000, 0.3000000, 0.6000000,
|
||||||
|
0.9000000, 1.2000000, 0.8000000, 1.0000000,
|
||||||
|
1.2000000, 1.4000000, 1.6000000, 1.8000001,
|
||||||
|
4.5000000, 5.0000000, 5.5000000, 6.0000000,
|
||||||
|
6.5000000, 7.0000000, 5.1000004, 5.4000001,
|
||||||
|
5.7000003, 6.0000000, 6.3000002, 6.6000004,
|
||||||
|
4.4000001, 4.5999999, 4.8000002, 5.0000000,
|
||||||
|
5.2000003, 5.4000001, 13.5000000, 14.0000000,
|
||||||
|
14.5000000, 15.0000000, 15.5000000, 16.0000000},
|
||||||
|
Shape{3}, vector<uint8_t>{1, 2, 3});
|
||||||
|
test_dequantizeLinearFp32(
|
||||||
|
Shape{2, 3, 2, 3},
|
||||||
|
vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
|
||||||
|
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1,
|
||||||
|
vector<float>{0.0000000, 0.3000000, 0.6000000, 0.9000000,
|
||||||
|
1.2000000, 1.5000000, 1.2000000, 1.4000000,
|
||||||
|
1.6000000, 1.8000001, 2.0000000, 2.2000000,
|
||||||
|
6.0000000, 6.5000000, 7.0000000, 7.5000000,
|
||||||
|
8.0000000, 8.5000000, 5.4000001, 5.7000003,
|
||||||
|
6.0000000, 6.3000002, 6.6000004, 6.9000001,
|
||||||
|
4.8000002, 5.0000000, 5.2000003, 5.4000001,
|
||||||
|
5.5999999, 5.8000002, 15.0000000, 15.5000000,
|
||||||
|
16.0000000, 16.5000000, 17.0000000, 17.5000000});
|
||||||
|
|
||||||
|
} // python output
|
||||||
|
TEST(CUDA_DequantizeLinearFp16, run) {
|
||||||
|
test_dequantizeLinearFp16(
|
||||||
|
Shape{2, 3, 2, 3},
|
||||||
|
vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
|
||||||
|
Shape{3}, ValGenerator<2>(), 1,
|
||||||
|
vector<float>{-2., 0., 2., 4., 6., 8., 8., 10., 12.,
|
||||||
|
14., 16., 18., 18., 20., 22., 24., 26., 28.,
|
||||||
|
34., 36., 38., 40., 42., 44., 44., 46., 48.,
|
||||||
|
50., 52., 54., 54., 56., 58., 60., 62., 64.},
|
||||||
|
Shape{3}, vector<uint8_t>{1, 2, 3});
|
||||||
|
test_dequantizeLinearFp16(
|
||||||
|
Shape{2, 3, 2, 3},
|
||||||
|
vector<uint8_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
|
||||||
|
Shape{3}, ValGenerator<2>(), 1,
|
||||||
|
vector<float>{0., 2., 4., 6., 8., 10., 12., 14., 16.,
|
||||||
|
18., 20., 22., 24., 26., 28., 30., 32., 34.,
|
||||||
|
36., 38., 40., 42., 44., 46., 48., 50., 52.,
|
||||||
|
54., 56., 58., 60., 62., 64., 66., 68., 70.});
|
||||||
|
|
||||||
|
} // python output
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue