feature: add parameter to config matmul compute type (#218)

* feature: add parameter to config matmul compute type

* fix format
This commit is contained in:
Chenjie Duan 2024-03-26 09:00:45 +08:00 committed by GitHub
parent 00e6cc2587
commit 54a35772fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 67 additions and 16 deletions

View File

@ -30,7 +30,8 @@ class GraphHandlerObj {
int pw, int sh, int sw, int dh, int dw, int oph, int pw, int sh, int sw, int dh, int dw, int oph,
int opw); int opw);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB, Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act); Tensor bias, ActType act,
std::string matmul_compute_type = "default");
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean, Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias, Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training); float momentum, float eps, bool training);

View File

@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
// Auxiliary attributes which are not a part of operator attributes. // Auxiliary attributes which are not a part of operator attributes.
int b, m, n, k; int b, m, n, k;
// Specifies the data precision for the matrix multiply.
std::string computeType = "default";
public: public:
/** /**
* @brief Matmul operator with batch broadcast and tensor transpose * @brief Matmul operator with batch broadcast and tensor transpose
@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
* @param transB If matrix B should be transposed when computing. * @param transB If matrix B should be transposed when computing.
* @param bias The bias tensor. * @param bias The bias tensor.
* @param act The activation function. * @param act The activation function.
* @param computeType Specifies the data precision for the matrix multiply.
*/ */
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false, Tensor bias = nullptr, bool transA = false, bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None); ActType act = ActType::None, std::string computeType = "default");
OP_CLONE(MatmulObj); OP_CLONE(MatmulObj);
std::string toString() const override; std::string toString() const override;
@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
int getN() const { return n; } int getN() const { return n; }
int getK() const { return k; } int getK() const { return k; }
auto getBMNK() const { return tuple{b, m, n, k}; } auto getBMNK() const { return tuple{b, m, n, k}; }
std::string getComputeType() const { return computeType; }
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;

View File

@ -37,7 +37,13 @@ class OnnxStub:
It can be generated from an Onnx model object. It can be generated from an Onnx model object.
""" """
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False): def __init__(
self,
model: ModelProto,
runtime,
use_naive_allocator: bool = False,
matmul_compute_type: str = "default",
):
# We use some user-defined operators for distributed inference # We use some user-defined operators for distributed inference
try: try:
# onnx simplifier performs inplace simplify # onnx simplifier performs inplace simplify
@ -215,6 +221,7 @@ class OnnxStub:
False, False,
None, None,
backend.ActType.Linear, backend.ActType.Linear,
matmul_compute_type,
) )
elif node.op_type == "Gemm": elif node.op_type == "Gemm":
attributes = _parse_attribute( attributes = _parse_attribute(
@ -234,6 +241,7 @@ class OnnxStub:
transB == 1, transB == 1,
tensors[node.input[2]] if len(node.input) > 2 else None, tensors[node.input[2]] if len(node.input) > 2 else None,
backend.ActType.Linear, backend.ActType.Linear,
matmul_compute_type,
) )
elif node.op_type == "BatchNormalization": elif node.op_type == "BatchNormalization":
(input, mean, var, scale, bias) = ( (input, mean, var, scale, bias) = (

View File

@ -73,15 +73,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
} }
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA, Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
bool transB, Tensor bias, ActType act) { bool transB, Tensor bias, ActType act,
std::string matmul_compute_type) {
if (y) { if (y) {
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA, g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
transB, std::move(bias), act); transB, std::move(bias), act,
matmul_compute_type);
return y; return y;
} else { } else {
return g return g
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB, ->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
std::move(bias), act) std::move(bias), act, matmul_compute_type)
->getOutput(); ->getOutput();
} }
} }

View File

@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20, CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23, CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
}; };
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
if (cuDataType == CUDA_R_16F) {
return CUBLAS_COMPUTE_32F_FAST_16F;
} else if (cuDataType == CUDA_R_16BF) {
return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F;
} else if (cuDataType == CUDA_R_64F) {
return CUBLAS_COMPUTE_64F;
} else {
IT_TODO_HALT();
}
}
cublasComputeType_t getCuComputeType(std::string computeTypeStr,
cudaDataType_t cuDataType) {
if (computeTypeStr == "tf32") {
return CUBLAS_COMPUTE_32F_FAST_TF32;
} else if (computeTypeStr == "bf16") {
return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (computeTypeStr == "fp16") {
return CUBLAS_COMPUTE_32F_FAST_16F;
} else if (computeTypeStr == "default") {
return cuDataType2ComputeType(cuDataType);
} else {
IT_TODO_HALT();
}
}
class matmulCublas : public Kernel { class matmulCublas : public Kernel {
bool do_compute(const Operator &_op, const PerfRecord &_record, bool do_compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const { const RuntimeObj *_context) const {
@ -78,6 +108,9 @@ class matmulCublas : public Kernel {
} }
// TODO:use compute type // TODO:use compute type
cublasStatus_t stat; cublasStatus_t stat;
std::string computeTypeStr = op->getComputeType();
auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
if (b > 1) { if (b > 1) {
// Support batch broadcast with zero stride // Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getRank(); int dimA = op->getInputs(0)->getRank();
@ -99,14 +132,14 @@ class matmulCublas : public Kernel {
context->cublasHandle(), opB, opA, n, m, k, &alpha_half, context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda, inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_half, outData, cuDataType, ldc, m * n, b, strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo); cuComputeType, (cublasGemmAlgo_t)record->algo);
} else { } else {
stat = cublasGemmStridedBatchedEx( stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha_naive, context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda, inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_naive, outData, cuDataType, ldc, m * n, b, strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo); cuComputeType, (cublasGemmAlgo_t)record->algo);
} }
} else { } else {
if (dataType == DataType::Float16) { if (dataType == DataType::Float16) {
@ -115,13 +148,13 @@ class matmulCublas : public Kernel {
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k, stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_half, inBData, cuDataType, ldb, &alpha_half, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_half, inAData, cuDataType, lda, &beta_half,
outData, cuDataType, ldc, cuDataType, outData, cuDataType, ldc, cuComputeType,
(cublasGemmAlgo_t)record->algo); (cublasGemmAlgo_t)record->algo);
} else { } else {
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k, stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_naive, inBData, cuDataType, ldb, &alpha_naive, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_naive, inAData, cuDataType, lda, &beta_naive,
outData, cuDataType, ldc, cuDataType, outData, cuDataType, ldc, cuComputeType,
(cublasGemmAlgo_t)record->algo); (cublasGemmAlgo_t)record->algo);
} }
} }

View File

@ -5,10 +5,11 @@
namespace infini { namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act) bool transB, [[maybe_unused]] Tensor bias, ActType act,
std::string computeType)
: OperatorObj(OpType::MatMul, : OperatorObj(OpType::MatMul,
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}), bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
transA(transA), transB(transB), act(act), b(1) { transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
@ -17,7 +18,8 @@ string MatmulObj::toString() const {
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B") os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
<< ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid() << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid() << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"; << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
<< ",computeType=" << computeType;
return os.str(); return os.str();
} }