feature: add parameter to config matmul compute type (#218)

* feature: add parameter to config matmul compute type

* fix format
Chenjie Duan, 2024-03-26 09:00:45 +08:00 (committed by GitHub)
parent 00e6cc2587
commit 54a35772fb
6 changed files with 67 additions and 16 deletions
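
The new option is threaded from the Python frontend down to the cuBLAS kernel: OnnxStub gains a matmul_compute_type keyword, GraphHandlerObj::matmul and MatmulObj carry it as a string, and the CUDA matmul kernel maps it to a cublasComputeType_t. A minimal usage sketch follows; the import layout, model path, and runtime constructor are assumptions for illustration, while the accepted strings ("default", "tf32", "bf16", "fp16") come from getCuComputeType in this commit.

import onnx
from pyinfinitensor import backend         # assumed module layout
from pyinfinitensor.onnx import OnnxStub   # assumed import path

# Load an ONNX model and build it on a CUDA runtime (constructor name assumed).
model = onnx.load("model.onnx")
runtime = backend.cuda_runtime()

# Run every MatMul/Gemm with TF32 math instead of the dtype-derived default;
# "bf16", "fp16", and "default" are the other accepted values.
stub = OnnxStub(model, runtime, matmul_compute_type="tf32")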


@@ -30,7 +30,8 @@ class GraphHandlerObj {
int pw, int sh, int sw, int dh, int dw, int oph,
int opw);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act);
Tensor bias, ActType act,
std::string matmul_compute_type = "default");
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training);


@@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
// Auxiliary attributes which are not a part of operator attributes.
int b, m, n, k;
// Specifies the data precision for the matrix multiply.
std::string computeType = "default";
public:
/**
* @brief Matmul operator with batch broadcast and tensor transpose
@@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
* @param transB If matrix B should be transposed when computing.
* @param bias The bias tensor.
* @param act The activation function.
* @param computeType Specifies the data precision for the matrix multiply.
*/
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
ActType act = ActType::None, std::string computeType = "default");
OP_CLONE(MatmulObj);
std::string toString() const override;
@@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
int getN() const { return n; }
int getK() const { return k; }
auto getBMNK() const { return tuple{b, m, n, k}; }
std::string getComputeType() const { return computeType; }
private:
vector<int> getWorkloadVector() const override;


@@ -37,7 +37,13 @@ class OnnxStub:
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
def __init__(
self,
model: ModelProto,
runtime,
use_naive_allocator: bool = False,
matmul_compute_type: str = "default",
):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
@@ -215,6 +221,7 @@ class OnnxStub:
False,
None,
backend.ActType.Linear,
matmul_compute_type,
)
elif node.op_type == "Gemm":
attributes = _parse_attribute(
@@ -234,6 +241,7 @@
transB == 1,
tensors[node.input[2]] if len(node.input) > 2 else None,
backend.ActType.Linear,
matmul_compute_type,
)
elif node.op_type == "BatchNormalization":
(input, mean, var, scale, bias) = (
@@ -618,7 +626,7 @@
keep_aspect_ratio_policy,
nearest_mode,
coordinate_transformation_mode,
)
)
elif node.op_type == "Squeeze":
axes = (
_parse_data(data[node.input[1]])
@@ -962,7 +970,7 @@
beta,
bias,
size,
)
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
@@ -1247,7 +1255,7 @@
axes,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Concat:
axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))


@@ -73,15 +73,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
}
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
bool transB, Tensor bias, ActType act) {
bool transB, Tensor bias, ActType act,
std::string matmul_compute_type) {
if (y) {
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
transB, std::move(bias), act);
transB, std::move(bias), act,
matmul_compute_type);
return y;
} else {
return g
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
std::move(bias), act)
std::move(bias), act, matmul_compute_type)
->getOutput();
}
}


@@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
};
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
if (cuDataType == CUDA_R_16F) {
return CUBLAS_COMPUTE_32F_FAST_16F;
} else if (cuDataType == CUDA_R_16BF) {
return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F;
} else if (cuDataType == CUDA_R_64F) {
return CUBLAS_COMPUTE_64F;
} else {
IT_TODO_HALT();
}
}
cublasComputeType_t getCuComputeType(std::string computeTypeStr,
cudaDataType_t cuDataType) {
if (computeTypeStr == "tf32") {
return CUBLAS_COMPUTE_32F_FAST_TF32;
} else if (computeTypeStr == "bf16") {
return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (computeTypeStr == "fp16") {
return CUBLAS_COMPUTE_32F_FAST_16F;
} else if (computeTypeStr == "default") {
return cuDataType2ComputeType(cuDataType);
} else {
IT_TODO_HALT();
}
}
class matmulCublas : public Kernel {
bool do_compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const {
@@ -78,6 +108,9 @@ class matmulCublas : public Kernel {
}
// TODO:use compute type
cublasStatus_t stat;
std::string computeTypeStr = op->getComputeType();
auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
if (b > 1) {
// Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getRank();
@@ -99,14 +132,14 @@ class matmulCublas : public Kernel {
context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo);
cuComputeType, (cublasGemmAlgo_t)record->algo);
} else {
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo);
cuComputeType, (cublasGemmAlgo_t)record->algo);
}
} else {
if (dataType == DataType::Float16) {
@@ -115,13 +148,13 @@
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_half, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_half,
outData, cuDataType, ldc, cuDataType,
outData, cuDataType, ldc, cuComputeType,
(cublasGemmAlgo_t)record->algo);
} else {
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_naive, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_naive,
outData, cuDataType, ldc, cuDataType,
outData, cuDataType, ldc, cuComputeType,
(cublasGemmAlgo_t)record->algo);
}
}
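
For reference on the mapping above: "tf32", "bf16", and "fp16" select the CUBLAS_COMPUTE_32F_FAST_* variants, which keep FP32 accumulation while allowing cuBLAS to use TF32, BF16, or FP16 multiplications internally; "default" falls back to cuDataType2ComputeType, i.e. the compute type implied by the tensors' own data type (FP16 inputs get CUBLAS_COMPUTE_32F_FAST_16F, FP32 gets CUBLAS_COMPUTE_32F, FP64 gets CUBLAS_COMPUTE_64F).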


@@ -5,10 +5,11 @@
namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act)
bool transB, [[maybe_unused]] Tensor bias, ActType act,
std::string computeType)
: OperatorObj(OpType::MatMul,
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
transA(transA), transB(transB), act(act), b(1) {
transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
IT_ASSERT(checkValid(graph));
}
@@ -17,7 +18,8 @@ string MatmulObj::toString() const {
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
<< ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
<< ",computeType=" << computeType;
return os.str();
}
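
With the new field appended, MatmulObj::toString() now produces lines of the form Matmul([A,B,act=0],A=1,B=2,C=3,bmnk=[1,64,64,64]),computeType=tf32 (the GUIDs, shape, and act value here are illustrative).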