forked from jiuyuan/InfiniTensor
feature: add parameter to config matmul compute type (#218)
* feature: add parameter to config matmul compute type
* fix format
parent 00e6cc2587
commit 54a35772fb
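With this change, the matrix-multiply precision can be chosen when importing an ONNX model. A minimal usage sketch, assuming a CUDA runtime constructed via backend.cuda_runtime() and a model file named model.onnx (neither appears in this diff):

import onnx
from pyinfinitensor.onnx import OnnxStub, backend

model = onnx.load("model.onnx")   # hypothetical model path
runtime = backend.cuda_runtime()  # assumed runtime constructor, not shown in this diff

# Forward every MatMul/Gemm to cuBLAS with the TF32 compute type.
# Accepted strings are "default", "tf32", "bf16", and "fp16";
# see getCuComputeType in the CUDA kernel below.
stub = OnnxStub(model, runtime, matmul_compute_type="tf32")
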
@@ -30,7 +30,8 @@ class GraphHandlerObj {
                 int pw, int sh, int sw, int dh, int dw, int oph,
                 int opw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
-                  Tensor bias, ActType act);
+                  Tensor bias, ActType act,
+                  std::string matmul_compute_type = "default");
     Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
                               Tensor var, Tensor scale, Tensor bias,
                               float momentum, float eps, bool training);

@@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
     // Auxiliary attributes which are not a part of operator attributes.
     int b, m, n, k;
 
+    // Specifies the data precision for the matrix multiply.
+    std::string computeType = "default";
+
   public:
     /**
      * @brief Matmul operator with batch broadcast and tensor transpose

@@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
      * @param transB If matrix B should be transposed when computing.
      * @param bias The bias tensor.
      * @param act The activation function.
+     * @param computeType Specifies the data precision for the matrix multiply.
      */
     MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
               bool transA = false, bool transB = false, Tensor bias = nullptr,
-              ActType act = ActType::None);
+              ActType act = ActType::None, std::string computeType = "default");
     OP_CLONE(MatmulObj);
 
     std::string toString() const override;

@@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
     int getN() const { return n; }
     int getK() const { return k; }
     auto getBMNK() const { return tuple{b, m, n, k}; }
+    std::string getComputeType() const { return computeType; }
 
   private:
     vector<int> getWorkloadVector() const override;

@@ -37,7 +37,13 @@ class OnnxStub:
     It can be generated from an Onnx model object.
     """
 
-    def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
+    def __init__(
+        self,
+        model: ModelProto,
+        runtime,
+        use_naive_allocator: bool = False,
+        matmul_compute_type: str = "default",
+    ):
         # We use some user-defined operators for distributed inference
         try:
             # onnx simplifier performs inplace simplify

@@ -215,6 +221,7 @@ class OnnxStub:
                     False,
                     None,
                     backend.ActType.Linear,
+                    matmul_compute_type,
                 )
             elif node.op_type == "Gemm":
                 attributes = _parse_attribute(

@@ -234,6 +241,7 @@ class OnnxStub:
                     transB == 1,
                     tensors[node.input[2]] if len(node.input) > 2 else None,
                     backend.ActType.Linear,
+                    matmul_compute_type,
                 )
             elif node.op_type == "BatchNormalization":
                 (input, mean, var, scale, bias) = (

@@ -73,15 +73,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
 }
 
 Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
-                               bool transB, Tensor bias, ActType act) {
+                               bool transB, Tensor bias, ActType act,
+                               std::string matmul_compute_type) {
     if (y) {
         g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
-                                       transB, std::move(bias), act);
+                                       transB, std::move(bias), act,
+                                       matmul_compute_type);
         return y;
     } else {
         return g
             ->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
-                               std::move(bias), act)
+                               std::move(bias), act, matmul_compute_type)
             ->getOutput();
     }
 }

@@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
     CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
     CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
 };
+
+cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
+    if (cuDataType == CUDA_R_16F) {
+        return CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (cuDataType == CUDA_R_16BF) {
+        return CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (cuDataType == CUDA_R_32F) {
+        return CUBLAS_COMPUTE_32F;
+    } else if (cuDataType == CUDA_R_64F) {
+        return CUBLAS_COMPUTE_64F;
+    } else {
+        IT_TODO_HALT();
+    }
+}
+
+cublasComputeType_t getCuComputeType(std::string computeTypeStr,
+                                     cudaDataType_t cuDataType) {
+    if (computeTypeStr == "tf32") {
+        return CUBLAS_COMPUTE_32F_FAST_TF32;
+    } else if (computeTypeStr == "bf16") {
+        return CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (computeTypeStr == "fp16") {
+        return CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (computeTypeStr == "default") {
+        return cuDataType2ComputeType(cuDataType);
+    } else {
+        IT_TODO_HALT();
+    }
+}
+
 class matmulCublas : public Kernel {
     bool do_compute(const Operator &_op, const PerfRecord &_record,
                     const RuntimeObj *_context) const {

@@ -78,6 +108,9 @@ class matmulCublas : public Kernel {
         }
         // TODO:use compute type
         cublasStatus_t stat;
+        std::string computeTypeStr = op->getComputeType();
+        auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
+
         if (b > 1) {
             // Support batch broadcast with zero stride
             int dimA = op->getInputs(0)->getRank();

@@ -99,14 +132,14 @@ class matmulCublas : public Kernel {
                 context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
                 inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                 strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
-                cuDataType, (cublasGemmAlgo_t)record->algo);
+                cuComputeType, (cublasGemmAlgo_t)record->algo);
 
         } else {
             stat = cublasGemmStridedBatchedEx(
                 context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
                 inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                 strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
-                cuDataType, (cublasGemmAlgo_t)record->algo);
+                cuComputeType, (cublasGemmAlgo_t)record->algo);
         }
     } else {
         if (dataType == DataType::Float16) {

@@ -115,13 +148,13 @@ class matmulCublas : public Kernel {
             stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                 &alpha_half, inBData, cuDataType, ldb,
                                 inAData, cuDataType, lda, &beta_half,
-                                outData, cuDataType, ldc, cuDataType,
+                                outData, cuDataType, ldc, cuComputeType,
                                 (cublasGemmAlgo_t)record->algo);
         } else {
             stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                 &alpha_naive, inBData, cuDataType, ldb,
                                 inAData, cuDataType, lda, &beta_naive,
-                                outData, cuDataType, ldc, cuDataType,
+                                outData, cuDataType, ldc, cuComputeType,
                                 (cublasGemmAlgo_t)record->algo);
         }
     }

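For quick reference, the mapping above can be restated at the Python layer. A documentation-only sketch; check_compute_type is a hypothetical helper that is not part of this commit, but early validation is useful because an unsupported string otherwise only fails later inside the kernel via IT_TODO_HALT:

# Mirror of getCuComputeType: which cuBLAS compute type each string selects.
CUBLAS_COMPUTE_TYPES = {
    "tf32": "CUBLAS_COMPUTE_32F_FAST_TF32",
    "bf16": "CUBLAS_COMPUTE_32F_FAST_16BF",
    "fp16": "CUBLAS_COMPUTE_32F_FAST_16F",
    "default": "derived from the tensor dtype by cuDataType2ComputeType",
}

def check_compute_type(name: str) -> str:
    # Hypothetical early check before handing the string to OnnxStub.
    if name not in CUBLAS_COMPUTE_TYPES:
        raise ValueError(f"unsupported matmul_compute_type: {name!r}")
    return name
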
@@ -5,10 +5,11 @@
 namespace infini {
 
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
-                     bool transB, [[maybe_unused]] Tensor bias, ActType act)
+                     bool transB, [[maybe_unused]] Tensor bias, ActType act,
+                     std::string computeType)
     : OperatorObj(OpType::MatMul,
                   bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
-      transA(transA), transB(transB), act(act), b(1) {
+      transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
     IT_ASSERT(checkValid(graph));
 }

@@ -17,7 +18,8 @@ string MatmulObj::toString() const {
     os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
        << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
        << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
-       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
+       << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
+       << ",computeType=" << computeType;
     return os.str();
 }

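The extended debug string now reports the compute type. For a hypothetical operator with tensor GUIDs 10, 11, 12, no transposes, act = None (underlying value 0), and bmnk = [1, 2, 3, 4], toString() would yield:

Matmul([A,B,act=0],A=10,B=11,C=12,bmnk=[1,2,3,4]),computeType=default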