feature: add parameter to config matmul compute type (#218)
* feature: add parameter to config matmul compute type

* fix format
This commit is contained in:
parent 00e6cc2587
commit 54a35772fb
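
In short: this change threads a matmul_compute_type string from the Python
OnnxStub constructor through GraphHandlerObj::matmul and the MatmulObj
operator down to the cuBLAS matmul kernel, where getCuComputeType maps
"tf32", "bf16", "fp16", or "default" to the cublasComputeType_t passed to
cublasGemmEx / cublasGemmStridedBatchedEx.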
@@ -30,7 +30,8 @@ class GraphHandlerObj {
                           int pw, int sh, int sw, int dh, int dw, int oph,
                           int opw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
-                  Tensor bias, ActType act);
+                  Tensor bias, ActType act,
+                  std::string matmul_compute_type = "default");
     Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
                               Tensor var, Tensor scale, Tensor bias,
                               float momentum, float eps, bool training);

@@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
     // Auxiliary attributes which are not a part of operator attributes.
     int b, m, n, k;

+    // Specifies the data precision for the matrix multiply.
+    std::string computeType = "default";
+
   public:
     /**
      * @brief Matmul operator with batch broadcast and tensor transpose
@@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
      * @param transB If matrix B should be transposed when computing.
      * @param bias The bias tensor.
      * @param act The activation function.
+     * @param computeType Specifies the data precision for the matrix multiply.
      */
     MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
               bool transA = false, bool transB = false, Tensor bias = nullptr,
-              ActType act = ActType::None);
+              ActType act = ActType::None, std::string computeType = "default");
     OP_CLONE(MatmulObj);

     std::string toString() const override;
@@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
     int getN() const { return n; }
     int getK() const { return k; }
     auto getBMNK() const { return tuple{b, m, n, k}; }
+    std::string getComputeType() const { return computeType; }

   private:
     vector<int> getWorkloadVector() const override;

@@ -37,7 +37,13 @@ class OnnxStub:
     It can be generated from an Onnx model object.
     """

-    def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
+    def __init__(
+        self,
+        model: ModelProto,
+        runtime,
+        use_naive_allocator: bool = False,
+        matmul_compute_type: str = "default",
+    ):
         # We use some user-defined operators for distributed inference
         try:
             # onnx simplifier performs inplace simplify
@@ -215,6 +221,7 @@ class OnnxStub:
                     False,
                     None,
                     backend.ActType.Linear,
+                    matmul_compute_type,
                 )
             elif node.op_type == "Gemm":
                 attributes = _parse_attribute(
@@ -234,6 +241,7 @@ class OnnxStub:
                     transB == 1,
                     tensors[node.input[2]] if len(node.input) > 2 else None,
                     backend.ActType.Linear,
+                    matmul_compute_type,
                 )
             elif node.op_type == "BatchNormalization":
                 (input, mean, var, scale, bias) = (

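For reference, a minimal usage sketch of the new parameter from Python.
Only the OnnxStub signature above is taken from this commit; the model
path and the backend.cuda_runtime() accessor are assumptions used for
illustration:

    import onnx
    from pyinfinitensor.onnx import OnnxStub, backend

    model = onnx.load("model.onnx")  # placeholder model path

    # Accepted strings (see getCuComputeType below): "tf32", "bf16",
    # "fp16", and "default", which derives the compute type from the
    # tensor data type.
    stub = OnnxStub(model, backend.cuda_runtime(), matmul_compute_type="tf32")

Omitting the argument keeps the previous behavior, since every layer
defaults to "default".
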
@@ -73,15 +73,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
 }

 Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
-                               bool transB, Tensor bias, ActType act) {
+                               bool transB, Tensor bias, ActType act,
+                               std::string matmul_compute_type) {
     if (y) {
         g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
-                                       transB, std::move(bias), act);
+                                       transB, std::move(bias), act,
+                                       matmul_compute_type);
         return y;
     } else {
         return g
             ->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
-                               std::move(bias), act)
+                               std::move(bias), act, matmul_compute_type)
             ->getOutput();
     }
 }

@@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
     CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
     CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
 };
+
+cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
+    if (cuDataType == CUDA_R_16F) {
+        return CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (cuDataType == CUDA_R_16BF) {
+        return CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (cuDataType == CUDA_R_32F) {
+        return CUBLAS_COMPUTE_32F;
+    } else if (cuDataType == CUDA_R_64F) {
+        return CUBLAS_COMPUTE_64F;
+    } else {
+        IT_TODO_HALT();
+    }
+}
+
+cublasComputeType_t getCuComputeType(std::string computeTypeStr,
+                                     cudaDataType_t cuDataType) {
+    if (computeTypeStr == "tf32") {
+        return CUBLAS_COMPUTE_32F_FAST_TF32;
+    } else if (computeTypeStr == "bf16") {
+        return CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (computeTypeStr == "fp16") {
+        return CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (computeTypeStr == "default") {
+        return cuDataType2ComputeType(cuDataType);
+    } else {
+        IT_TODO_HALT();
+    }
+}
+
 class matmulCublas : public Kernel {
     bool do_compute(const Operator &_op, const PerfRecord &_record,
                     const RuntimeObj *_context) const {
@@ -78,6 +108,9 @@ class matmulCublas : public Kernel {
         }
         // TODO:use compute type
         cublasStatus_t stat;
+        std::string computeTypeStr = op->getComputeType();
+        auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
+
         if (b > 1) {
             // Support batch broadcast with zero stride
             int dimA = op->getInputs(0)->getRank();
@@ -99,14 +132,14 @@ class matmulCublas : public Kernel {
                 context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
                 inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                 strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
-                cuDataType, (cublasGemmAlgo_t)record->algo);
+                cuComputeType, (cublasGemmAlgo_t)record->algo);

         } else {
             stat = cublasGemmStridedBatchedEx(
                 context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
                 inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                 strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
-                cuDataType, (cublasGemmAlgo_t)record->algo);
+                cuComputeType, (cublasGemmAlgo_t)record->algo);
         }
     } else {
         if (dataType == DataType::Float16) {
@@ -115,13 +148,13 @@ class matmulCublas : public Kernel {
             stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                 &alpha_half, inBData, cuDataType, ldb,
                                 inAData, cuDataType, lda, &beta_half,
-                                outData, cuDataType, ldc, cuDataType,
+                                outData, cuDataType, ldc, cuComputeType,
                                 (cublasGemmAlgo_t)record->algo);
         } else {
             stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                 &alpha_naive, inBData, cuDataType, ldb,
                                 inAData, cuDataType, lda, &beta_naive,
-                                outData, cuDataType, ldc, cuDataType,
+                                outData, cuDataType, ldc, cuComputeType,
                                 (cublasGemmAlgo_t)record->algo);
         }
     }

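A note on the mapping in getCuComputeType above: the
CUBLAS_COMPUTE_32F_FAST_TF32 / _16BF / _16F modes keep FP32 inputs and
outputs but allow cuBLAS to use reduced-precision tensor-core math
internally, so "tf32", "bf16", and "fp16" trade matmul precision for
speed without changing tensor dtypes, while "default" picks a compute
type matching the data type (e.g. CUDA_R_32F -> CUBLAS_COMPUTE_32F).
A sketch of how a caller might gate the opt-in; this helper is
hypothetical and not part of this commit:

    # Hypothetical guard: the FAST_* modes are FP32-in/FP32-out
    # variants, so only forward them for fp32 models and fall back
    # to "default" otherwise.
    def choose_compute_type(requested: str, model_dtype: str) -> str:
        if requested not in {"tf32", "bf16", "fp16", "default"}:
            raise ValueError(f"unknown matmul compute type: {requested}")
        return requested if model_dtype == "float32" else "default"
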
@@ -5,10 +5,11 @@
 namespace infini {

 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
-                     bool transB, [[maybe_unused]] Tensor bias, ActType act)
+                     bool transB, [[maybe_unused]] Tensor bias, ActType act,
+                     std::string computeType)
     : OperatorObj(OpType::MatMul,
                   bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
-      transA(transA), transB(transB), act(act), b(1) {
+      transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
     IT_ASSERT(checkValid(graph));
 }

@@ -17,7 +18,8 @@ string MatmulObj::toString() const {
     os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
        << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
        << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
-       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
+       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
+       << ",computeType=" << computeType;
     return os.str();
 }