feat: add matmulinteger op

zhangyunze 2023-12-19 09:57:31 +08:00
parent 9c82936386
commit 97e3377ca5
15 changed files with 478 additions and 8 deletions

View File

@ -103,6 +103,8 @@ class GraphHandlerObj {
std::string mode);
TensorVec dynamicQuantizeLinear(Tensor input,
std::optional<TensorVec> outputs);
Tensor matmulInteger(Tensor inputA, Tensor inputB, Tensor output,
Tensor a_zero_point, Tensor b_zero_point);
//------ modifiers

View File

@ -0,0 +1,7 @@
#pragma once
namespace infini {
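// Helpers used by the MatmulInteger CUDA kernel: subtract the zero-point
// values in `b` from the int8/uint8 matrix `a` in place before the integer
// GEMM. `dType` is the DataType index (2 = UInt8, 3 = Int8), `size` is the
// element count of `a`, `k` (and `n` for B) describe the row layout, and
// `delta` selects a broadcast scalar zero point (0) or a per-row/per-column
// zero-point tensor (1).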
void subA_kernel(int dType, void *a, void *b, int size, int k, int delta);
void subB_kernel(int dType, void *a, void *b, int size, int k, int n,
int delta);
}; // namespace infini

View File

@ -0,0 +1,63 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Integer matrix multiplication, as defined by the ONNX MatMulInteger
* operator: int8/uint8 inputs with optional zero points, int32 output.
*
*/
class MatmulIntegerObj : public OperatorObj {
private:
// Auxiliary attributes which are not a part of operator attributes.
int b, m, n, k;
public:
/**
* @brief MatmulInteger operator with batch broadcast support. Only one
* tensor with a single batch can be broadcast due to the BLAS interface
* restriction. Unlike Matmul, this operator takes no transpose flags; the
* last two dimensions of A and B are multiplied as-is.
*
* The constructor of an operator can create output tensors for the operator
* or not, depending on `graph`.
*
* @param graph The computation graph that this operator belongs to.
* @param A The first input tensor.
* @param B The second input tensor.
* @param C C is the output of MatmulInteger. If outputs are going to be
* created in the constructor, C should be an empty Ref.
* @param a_zero_point Zero point tensor for input 'A'. Optional.
* @param b_zero_point Zero point tensor for input 'B'. Optional.
*/
MatmulIntegerObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
Tensor a_zero_point = nullptr,
Tensor b_zero_point = nullptr);
OP_CLONE(MatmulIntegerObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
Tensor getZeroPointA() const {
return inputs.size() > 2 ? inputs[2] : nullptr;
}
Tensor getZeroPointB() const {
return inputs.size() > 3 ? inputs[3] : nullptr;
}
int getB() const { return b; }
int getM() const { return m; }
int getN() const { return n; }
int getK() const { return k; }
auto getBMNK() const { return tuple{b, m, n, k}; }
vector<DataType> inferDataType(const TensorVec &inputs) const override;
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
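A minimal usage sketch (not part of this commit; it mirrors the CUDA tests further below): the operator is normally added through GraphObj::addOp, and passing nullptr for C lets the graph create the int32 output tensor.
// Sketch: assumes a CUDA runtime, as in the tests at the end of this commit.
auto runtime = make_ref<CudaRuntimeObj>();
auto g = make_ref<GraphObj>(runtime);
auto A = g->addTensor({1, 2, 4}, DataType::Int8);
auto B = g->addTensor({1, 4, 4}, DataType::Int8);
auto aZero = g->addTensor({}, DataType::Int8); // scalar zero point for A
auto op = g->addOp<MatmulIntegerObj>(A, B, nullptr, aZero, nullptr);
// op->getOutput() is an Int32 tensor of shape {1, 2, 4}.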

View File

@ -882,7 +882,14 @@ class OnnxStub:
inputZeroPoint,
axis,
)
elif node.op_type == "MatMulInteger":
tensors[node.output[0]] = self.handler.matmulInteger(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
tensors[node.input[2]] if len(node.input) > 2 else None,
tensors[node.input[3]] if len(node.input) > 3 else None,
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name)

View File

@ -102,6 +102,18 @@ class TestStringMethods(unittest.TestCase):
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
def test_matmul_integer(self):
A = make_tensor_value_info("A", TensorProto.INT8, [1, 2, 4])
B = make_tensor_value_info("B", TensorProto.UINT8, [1, 4, 4])
A_ZeroPoint = make_tensor_value_info("A_ZeroPoint", TensorProto.INT8, [1, 2, 1])
y = make_tensor_value_info("y", TensorProto.INT32, [1, 2, 4])
matmulInteger = make_node(
"MatMulInteger", ["A", "B", "A_ZeroPoint"], ["y"], name="matmul_integer"
)
make_and_import_model(
make_graph([matmulInteger], "matmul_integer", [A, B, A_ZeroPoint], [y])
)
def test_gemm(self):
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 3])

View File

@ -13,6 +13,7 @@
#include "operators/gather.h"
#include "operators/layer_norm.h"
#include "operators/matmul.h"
#include "operators/matmul_integer.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/recv.h"
@ -147,6 +148,23 @@ Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
}
}
Tensor GraphHandlerObj::matmulInteger(Tensor inputA, Tensor inputB,
Tensor output, Tensor a_zero_point,
Tensor b_zero_point) {
if (output) {
g->addOpWithOutputs<MatmulIntegerObj>(
std::move(inputA), std::move(inputB), output,
std::move(a_zero_point), std::move(b_zero_point));
return output;
} else {
return g
->addOp<MatmulIntegerObj>(std::move(inputA), std::move(inputB),
output, std::move(a_zero_point),
std::move(b_zero_point))
->getOutput();
}
}
// see operators/element_wise.h
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \

View File

@ -58,12 +58,13 @@ HashType OperatorObj::hash() const {
bool OperatorObj::checkValid(GraphObj *graph) {
auto optShapes = inferShape();
if (!optShapes) // shape inference failed
if (!optShapes) { // shape inference failed
return false;
}
const vector<Shape> &shapes = *optShapes;
if (shapes.size() != outputs.size())
if (shapes.size() != outputs.size()) {
return false;
}
if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) {

View File

@ -59,8 +59,8 @@ cudaDataType cublasDataTypeConvert(DataType dataType) {
switch (dataType.getIndex()) {
case 1:
return CUDA_R_32F;
// case 3:
// return CUDA_R_8I;
// case 3:
// return CUDA_R_8I;
case 10:
return CUDA_R_16F;
case 11:

View File

@ -518,6 +518,7 @@ void init_graph_builder(py::module &m) {
.def("erf", &Handler::erf, policy::move)
.def("where", &Handler::where, policy::move)
.def("dequantizeLinear", &Handler::dequantizeLinear, policy::move)
.def("matmulInteger", &Handler::matmulInteger, policy::move)
.def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("optimize", &Handler::optimize, policy::automatic)
.def("operators", &Handler::operators, policy::move)

View File

@ -52,11 +52,10 @@ class matmulCublas : public Kernel {
float alpha = 1.f, beta = 0.f;
auto dataType = op->getDType();
auto cuDataType = cublasDataTypeConvert(dataType);
IT_ASSERT(cuDataType != CUDA_R_8I, "matmul doesn't support the int8 dtype.");
if (op->numInputs() == 2) { // no bias
beta = 0.f;
} else { // broadcast bias to output
// IT_ASSERT(cuDataType != CUDA_R_8I,
// "MatMul bias don't support INT8.");
beta = 1.f;
auto inC = op->getInputs(2);
auto out = op->getOutput();

View File

@ -0,0 +1,109 @@
#include "operators/matmul_integer.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_matmul_integer.h"
#include "cuda/cuda_runtime.h"
#include "utils/small_array.h"
#include <thrust/transform.h>
namespace infini {
class matmulIntegerCublas : public CudaKernelWithoutConfig {
bool do_compute(const Operator &_op, const RuntimeObj *_context) const {
auto op = as<MatmulIntegerObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const inAData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inBData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
const auto [b, m, n, k] = op->getBMNK();
if (op->numInputs() >= 3) { // have a_zero_point
int aZeroSize = op->getInputs(2)->size();
int aSize = op->getInputs(0)->size();
void *const aZeroPointData =
(op->getInputs(2)->getRawDataPtr<void *>());
if (op->getInputs(0)->getDType() == DataType::Int8) {
if (aZeroSize > 1) {
subA_kernel(DataType::Int8.getIndex(), inAData,
aZeroPointData, aSize, k, 1);
} else {
subA_kernel(DataType::Int8.getIndex(), inAData,
aZeroPointData, aSize, k, 0);
}
}
if (op->getInputs(0)->getDType() == DataType::UInt8) {
if (aZeroSize > 1) {
subA_kernel(DataType::UInt8.getIndex(), inAData,
aZeroPointData, aSize, k, 1);
} else {
subA_kernel(DataType::UInt8.getIndex(), inAData,
aZeroPointData, aSize, k, 0);
}
}
}
if (op->numInputs() == 4) { // have b_zero_point
int bZeroSize = op->getInputs(3)->size();
int bSize = op->getInputs(1)->size();
void *const bZeroPointData =
(op->getInputs(3)->getRawDataPtr<void *>());
if (op->getInputs(1)->getDType() == DataType::Int8) {
if (bZeroSize > 1) {
subB_kernel(DataType::Int8.getIndex(), inBData,
bZeroPointData, bSize, k, n, 1);
} else {
subB_kernel(DataType::Int8.getIndex(), inBData,
bZeroPointData, bSize, k, n, 0);
}
}
if (op->getInputs(1)->getDType() == DataType::UInt8) {
if (bZeroSize > 1) {
subB_kernel(DataType::UInt8.getIndex(), inBData,
bZeroPointData, bSize, k, n, 1);
} else {
subB_kernel(DataType::UInt8.getIndex(), inBData,
bZeroPointData, bSize, k, n, 0);
}
}
}
int lda = k, ldb = n, ldc = n;
int32_t alpha = 1, beta = 0;
// TODO: use compute type
cublasStatus_t stat;
if (b > 1) {
// Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getRank();
int dimB = op->getInputs(1)->getRank();
long long strideA =
(dimA == 2 ||
(dimA == 3 && op->getInputs(0)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: m * k;
long long strideB =
(dimB == 2 ||
(dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: n * k;
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
&alpha, inBData, CUDA_R_8I, ldb, strideB, inAData, CUDA_R_8I,
lda, strideA, &beta, outData, CUDA_R_32I, ldc, m * n, b,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
} else {
stat = cublasGemmEx(
context->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
&alpha, inBData, CUDA_R_8I, ldb, inAData, CUDA_R_8I, lda, &beta,
outData, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
}
return (stat == CUBLAS_STATUS_SUCCESS);
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
IT_ASSERT(do_compute(_op, _context));
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MatMulInteger, matmulIntegerCublas,
"MatmulInteger_cuBLAS_CUDA");
}; // namespace infini
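For reference (not part of this commit's code), the zero-point handling above follows the ONNX MatMulInteger semantics: subA_kernel/subB_kernel subtract the zero points from A and B in place, and the subsequent int8 GEMM with int32 accumulation then computes
Y_{b,ij} = \sum_{p=1}^{k} \left(A_{b,ip} - \mathrm{a\_zero\_point}\right)\left(B_{b,pj} - \mathrm{b\_zero\_point}\right)
where the zero points may be scalars or broadcast tensors, selected by the delta flag passed to the helper kernels.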

View File

@ -0,0 +1,75 @@
#include "cuda/cuda_common.h"
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _subA_kernel(void *a, void *b, int size, int k, int delta) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i - i % k) / k;
((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
}
}
__global__ void _subA_u8_kernel(void *a, void *b, int size, int k, int delta) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i - i % k) / k;
auto aData = static_cast<int16_t>(((uint8_t *)a)[i]);
auto bData = static_cast<int16_t>(((uint8_t *)b)[j]);
((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);
}
}
__global__ void _subB_kernel(void *a, void *b, int size, int k, int n,
int delta) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i / k) + (i % n);
((int8_t *)a)[i] = ((int8_t *)a)[i] - ((int8_t *)b)[j];
}
}
__global__ void _subB_u8_kernel(void *a, void *b, int size, int k, int n,
int delta) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
int j = delta * (i / k) + (i % n);
auto aData = static_cast<int16_t>(((uint8_t *)a)[i]);
auto bData = static_cast<int16_t>(((uint8_t *)b)[j]);
((int8_t *)a)[i] = static_cast<int8_t>(aData - bData);
}
}
namespace infini {
void subA_kernel(int dType, void *a, void *b, int size, int k, int delta) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
if (dType == 3) {
_subA_kernel<<<gridsize, blocksize>>>(a, b, size, k, delta);
} else if (dType == 2) {
_subA_u8_kernel<<<gridsize, blocksize>>>(a, b, size, k, delta);
} else {
IT_TODO_HALT();
}
}
void subB_kernel(int dType, void *a, void *b, int size, int k, int n,
int delta) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
if (dType == 3) {
_subB_kernel<<<gridsize, blocksize>>>(a, b, size, k, n, delta);
} else if (dType == 2) {
_subB_u8_kernel<<<gridsize, blocksize>>>(a, b, size, k, n, delta);
} else {
IT_TODO_HALT();
}
}
}; // namespace infini

View File

@ -0,0 +1,75 @@
#include "operators/matmul_integer.h"
#include "utils/operator_utils.h"
#include <numeric>
namespace infini {
MatmulIntegerObj::MatmulIntegerObj(GraphObj *graph, Tensor A, Tensor B,
Tensor C,
[[maybe_unused]] Tensor a_zero_point,
[[maybe_unused]] Tensor b_zero_point)
: OperatorObj(OpType::MatMulInteger,
a_zero_point ? (b_zero_point ? TensorVec{A, B, a_zero_point,
b_zero_point}
: TensorVec{A, B, a_zero_point})
: TensorVec{A, B},
{C}),
b(1) {
IT_ASSERT(checkValid(graph));
}
string MatmulIntegerObj::toString() const {
std::ostringstream os;
os << "MatmulInteger(A=" << inputs[0]->getGuid()
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
return os.str();
}
optional<vector<Shape>> MatmulIntegerObj::inferShape(const TensorVec &inputs) {
auto A = inputs[0], B = inputs[1];
auto shapeA = A->getDims();
auto shapeB = B->getDims();
int rankA = A->getRank();
int rankB = B->getRank();
Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
Shape ret = infer_broadcast(shapeA1, shapeB1);
if (ret.empty()) {
b = 1;
} else {
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
}
IT_ASSERT(*(shapeA.rbegin()) == *(shapeB.rbegin() + 1));
m = *(shapeA.rbegin() + 1);
n = *(shapeB.rbegin());
k = *(shapeA.rbegin());
ret.emplace_back(m);
ret.emplace_back(n);
return {{ret}};
}
vector<DataType>
MatmulIntegerObj::inferDataType(const TensorVec &inputs) const {
for (auto &input : inputs) {
IT_ASSERT(input->getDType() == DataType::Int8 ||
input->getDType() == DataType::UInt8);
}
if (inputs.size() >= 3) {
IT_ASSERT(inputs[0]->getDType() == inputs[2]->getDType());
}
if (inputs.size() == 4) {
IT_ASSERT(inputs[1]->getDType() == inputs[3]->getDType());
}
return vector(numOutputs(), DataType::Int32);
}
vector<int> MatmulIntegerObj::getWorkloadVector() const {
return {type.underlying(), b, m, n, k};
}
vector<int> MatmulIntegerObj::getOpAttrVector() const {
return {type.underlying()};
}
} // namespace infini
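As a concrete example of the shape inference above: with A of shape [2, 3, 1, 4] and B of shape [2, 3, 4, 12] (the batched cuBLAS test below), the leading dimensions broadcast to [2, 3], so b = 6, m = 1, k = 4, n = 12, and the inferred output shape is [2, 3, 1, 12].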

View File

@ -0,0 +1,68 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "operators/matmul_integer.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<int32_t>;
TEST(cuBLAS_MatmulInteger, ZeroPoint1) {
auto cudaRuntime = make_ref<CudaRuntimeObj>();
auto gCuda = make_ref<GraphObj>(cudaRuntime);
auto ACuda = gCuda->addTensor({1, 4}, DataType::UInt8);
auto BCuda = gCuda->addTensor({4, 12}, DataType::UInt8);
auto AZeroPointCuda = gCuda->addTensor({}, DataType::UInt8);
auto BZeroPointCuda = gCuda->addTensor({}, DataType::UInt8);
auto op = gCuda->addOp<MatmulIntegerObj>(ACuda, BCuda, nullptr,
AZeroPointCuda, BZeroPointCuda);
// allocate CUDA memory
gCuda->dataMalloc();
// ACuda->copyin(vector<uint8_t>{11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0});
ACuda->copyin(vector<uint8_t>{11, 7, 3, 10});
// BCuda->copyin(vector<uint8_t>({1, 4, 2, 5, 3, 6,}));
BCuda->copyin(vector<uint8_t>(48, 1));
AZeroPointCuda->copyin(vector<uint8_t>{12});
BZeroPointCuda->copyin(vector<uint8_t>{0});
cudaRuntime->run(gCuda);
auto result = op->getOutput()->clone(NativeCpuRuntimeObj::getInstance());
// ExpectOutput ans = {
// -38, -83, -44, -98, -50, -113, -56, -128,
// };
ExpectOutput ans = {-17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17};
EXPECT_TRUE(result->equalData(ans));
}
TEST(cuBLAS_MatmulInteger, ZeroPoint2) {
auto cudaRuntime = make_ref<CudaRuntimeObj>();
auto gCuda = make_ref<GraphObj>(cudaRuntime);
auto ACuda = gCuda->addTensor({2, 3, 1, 4}, DataType::UInt8);
auto BCuda = gCuda->addTensor({2, 3, 4, 12}, DataType::UInt8);
auto AZeroPointCuda = gCuda->addTensor({2, 3, 1, 1}, DataType::UInt8);
auto BZeroPointCuda = gCuda->addTensor({2, 3, 1, 12}, DataType::UInt8);
auto op = gCuda->addOp<MatmulIntegerObj>(ACuda, BCuda, nullptr,
AZeroPointCuda, BZeroPointCuda);
// allocate CUDA memory
gCuda->dataMalloc();
ACuda->copyin(vector<uint8_t>{11, 7, 3, 10, 11, 7, 3, 10, 11, 7, 3, 10,
11, 7, 3, 10, 11, 7, 3, 10, 11, 7, 3, 10});
BCuda->copyin(vector<uint8_t>(288, 1));
AZeroPointCuda->copyin(vector<uint8_t>(6, 12));
BZeroPointCuda->copyin(vector<uint8_t>(72, 0));
cudaRuntime->run(gCuda);
auto result = op->getOutput()->clone(NativeCpuRuntimeObj::getInstance());
ExpectOutput ans = {-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17, -17, -17, -17, -17, -17,
-17, -17, -17, -17, -17, -17};
EXPECT_TRUE(result->equalData(ans));
}
}; // namespace infini

View File

@ -0,0 +1,33 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/matmul_integer.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<float>;
TEST(MatmulInteger, ShapeInference) {
auto runtime = NativeCpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
auto A = g->addTensor(Shape{1, 4, 2}, DataType::Int8);
auto B = g->addTensor(Shape{1, 2, 12}, DataType::Int8);
auto op = g->addOp<MatmulIntegerObj>(A, B, nullptr, nullptr, nullptr);
auto C = op->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{1, 4, 12}));
}
{
Graph g = make_ref<GraphObj>(runtime);
auto A = g->addTensor(Shape{1, 4, 2}, DataType::UInt8);
auto B = g->addTensor(Shape{1, 2, 12}, DataType::UInt8);
auto A_Zero = g->addTensor(Shape{1, 4, 1}, DataType::UInt8);
auto B_Zero = g->addTensor(Shape{1, 1, 12}, DataType::UInt8);
auto op = g->addOp<MatmulIntegerObj>(A, B, nullptr, A_Zero, B_Zero);
auto C = op->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{1, 4, 12}));
}
}
}; // namespace infini