diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index ce455d62..92a17048 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -30,7 +30,8 @@ class GraphHandlerObj {
                             int pw, int sh, int sw, int dh, int dw, int oph,
                             int opw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
-                  Tensor bias, ActType act);
+                  Tensor bias, ActType act,
+                  std::string matmul_compute_type = "default");
     Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
                               Tensor var, Tensor scale, Tensor bias,
                               float momentum, float eps, bool training);
diff --git a/include/operators/matmul.h b/include/operators/matmul.h
index 35a4c0a8..0c90c475 100644
--- a/include/operators/matmul.h
+++ b/include/operators/matmul.h
@@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
     // Auxiliary attributes which are not a part of operator attributes.
     int b, m, n, k;
 
+    // Specifies the data precision for the matrix multiply.
+    std::string computeType = "default";
+
   public:
     /**
      * @brief Matmul operator with batch broadcast and tensor transpose
@@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
      * @param transB If matrix B should be transposed when computing.
      * @param bias The bias tensor.
      * @param act The activation function.
+     * @param computeType Specifies the data precision for the matrix multiply.
      */
     MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
               bool transA = false, bool transB = false, Tensor bias = nullptr,
-              ActType act = ActType::None);
+              ActType act = ActType::None, std::string computeType = "default");
     OP_CLONE(MatmulObj);
 
     std::string toString() const override;
@@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
     int getN() const { return n; }
     int getK() const { return k; }
     auto getBMNK() const { return tuple{b, m, n, k}; }
+    std::string getComputeType() const { return computeType; }
 
   private:
     vector<int> getWorkloadVector() const override;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index a21e0b0a..522a4813 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -37,7 +37,13 @@ class OnnxStub:
     It can be generated from an Onnx model object.
""" - def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False): + def __init__( + self, + model: ModelProto, + runtime, + use_naive_allocator: bool = False, + matmul_compute_type: str = "default", + ): # We use some user-defined operators for distributed inference try: # onnx simplifier performs inplace simplify @@ -215,6 +221,7 @@ class OnnxStub: False, None, backend.ActType.Linear, + matmul_compute_type, ) elif node.op_type == "Gemm": attributes = _parse_attribute( @@ -234,6 +241,7 @@ class OnnxStub: transB == 1, tensors[node.input[2]] if len(node.input) > 2 else None, backend.ActType.Linear, + matmul_compute_type, ) elif node.op_type == "BatchNormalization": (input, mean, var, scale, bias) = ( @@ -618,7 +626,7 @@ class OnnxStub: keep_aspect_ratio_policy, nearest_mode, coordinate_transformation_mode, - ) + ) elif node.op_type == "Squeeze": axes = ( _parse_data(data[node.input[1]]) @@ -962,7 +970,7 @@ class OnnxStub: beta, bias, size, - ) + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) @@ -1247,7 +1255,7 @@ class OnnxStub: axes, ) ) - ctx.push_node(make_node(ty.name, inputs, outputs, name)) + ctx.push_node(make_node(ty.name, inputs, outputs, name)) elif ty == backend.OpTypeId.Concat: axis = backend.concat_axis_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index c8458454..4e9fa0d3 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -73,15 +73,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight, } Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA, - bool transB, Tensor bias, ActType act) { + bool transB, Tensor bias, ActType act, + std::string matmul_compute_type) { if (y) { g->addOpWithOutputs(std::move(a), std::move(b), y, transA, - transB, std::move(bias), act); + transB, std::move(bias), act, + matmul_compute_type); return y; } else { return g ->addOp(std::move(a), std::move(b), y, transA, transB, - std::move(bias), act) + std::move(bias), act, matmul_compute_type) ->getOutput(); } } diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index e2addde1..771cadb6 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = { CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20, CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23, }; + +cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) { + if (cuDataType == CUDA_R_16F) { + return CUBLAS_COMPUTE_32F_FAST_16F; + } else if (cuDataType == CUDA_R_16BF) { + return CUBLAS_COMPUTE_32F_FAST_16BF; + } else if (cuDataType == CUDA_R_32F) { + return CUBLAS_COMPUTE_32F; + } else if (cuDataType == CUDA_R_64F) { + return CUBLAS_COMPUTE_64F; + } else { + IT_TODO_HALT(); + } +} + +cublasComputeType_t getCuComputeType(std::string computeTypeStr, + cudaDataType_t cuDataType) { + if (computeTypeStr == "tf32") { + return CUBLAS_COMPUTE_32F_FAST_TF32; + } else if (computeTypeStr == "bf16") { + return CUBLAS_COMPUTE_32F_FAST_16BF; + } else if (computeTypeStr == "fp16") { + return CUBLAS_COMPUTE_32F_FAST_16F; + } else if (computeTypeStr == "default") { + return cuDataType2ComputeType(cuDataType); + } else { + IT_TODO_HALT(); + } +} + class matmulCublas : public Kernel { bool do_compute(const Operator &_op, const PerfRecord &_record, const RuntimeObj *_context) const { @@ -78,6 +108,9 @@ class matmulCublas : public Kernel { } // 
         // TODO:use compute type
         cublasStatus_t stat;
+        std::string computeTypeStr = op->getComputeType();
+        auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
+
         if (b > 1) {
             // Support batch broadcast with zero stride
             int dimA = op->getInputs(0)->getRank();
@@ -99,14 +132,14 @@ class matmulCublas : public Kernel {
                     context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
                     inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                     strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
-                    cuDataType, (cublasGemmAlgo_t)record->algo);
+                    cuComputeType, (cublasGemmAlgo_t)record->algo);
             } else {
                 stat = cublasGemmStridedBatchedEx(
                     context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
                     inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                     strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
-                    cuDataType, (cublasGemmAlgo_t)record->algo);
+                    cuComputeType, (cublasGemmAlgo_t)record->algo);
             }
         } else {
             if (dataType == DataType::Float16) {
@@ -115,13 +148,13 @@ class matmulCublas : public Kernel {
                 stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m,
                                     k, &alpha_half, inBData, cuDataType, ldb,
                                     inAData, cuDataType, lda, &beta_half,
-                                    outData, cuDataType, ldc, cuDataType,
+                                    outData, cuDataType, ldc, cuComputeType,
                                     (cublasGemmAlgo_t)record->algo);
             } else {
                 stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m,
                                     k, &alpha_naive, inBData, cuDataType, ldb,
                                     inAData, cuDataType, lda, &beta_naive,
-                                    outData, cuDataType, ldc, cuDataType,
+                                    outData, cuDataType, ldc, cuComputeType,
                                     (cublasGemmAlgo_t)record->algo);
             }
         }
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index db4533a7..15f9f8fd 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -5,10 +5,11 @@ namespace infini {
 
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
-                     bool transB, [[maybe_unused]] Tensor bias, ActType act)
+                     bool transB, [[maybe_unused]] Tensor bias, ActType act,
+                     std::string computeType)
     : OperatorObj(OpType::MatMul,
                   bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
-      transA(transA), transB(transB), act(act), b(1) {
+      transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
     IT_ASSERT(checkValid(graph));
 }
 
@@ -17,7 +18,8 @@ string MatmulObj::toString() const {
     os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
       << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
       << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
-       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
+       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
+       << ",computeType=" << computeType;
     return os.str();
 }
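
Usage sketch (not part of the patch): the new knob surfaces in Python through OnnxStub, which forwards it to every MatMul and Gemm node it imports. Per getCuComputeType above, the accepted strings are "default", "fp16", "bf16", and "tf32"; any other value halts inside the CUDA matmul kernel via IT_TODO_HALT(). The model path below is a placeholder, and backend.cuda_runtime() is assumed to be available in a CUDA-enabled build, so treat this as illustrative rather than authoritative:

    import onnx
    from pyinfinitensor import backend
    from pyinfinitensor.onnx import OnnxStub

    # Import an ONNX model and ask cuBLAS to run all MatMul/Gemm nodes with
    # CUBLAS_COMPUTE_32F_FAST_TF32, i.e. FP32 GEMMs that may use TF32 tensor
    # cores for the multiplies while still accumulating in FP32.
    model = onnx.load("model.onnx")  # placeholder path
    stub = OnnxStub(model, backend.cuda_runtime(), matmul_compute_type="tf32")

With the value "default", the kernel derives the cuBLAS compute type from the tensor dtype via cuDataType2ComputeType, so FP32 models keep plain CUBLAS_COMPUTE_32F unless a faster mode is requested explicitly.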