fix: pass half-precision alpha/beta to cuBLAS for fp16 matmul

zhangyunze 2023-12-29 16:55:16 +08:00
parent a91ed84354
commit 935b465cf2
1 changed file with 34 additions and 14 deletions


@@ -49,14 +49,12 @@ class matmulCublas : public Kernel {
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
ldc = n;
- float alpha = 1.f, beta = 0.f;
+ float alpha_naive = 1.f, beta_naive = 0.f;
auto dataType = op->getDType();
auto cuDataType = cublasDataTypeConvert(dataType);
IT_ASSERT(cuDataType != CUDA_R_8I, "matmul doesn't support int8 dtype.");
- if (op->numInputs() == 2) { // no bias
-     beta = 0.f;
- } else { // broadcast bias to output
-     beta = 1.f;
+ if (op->numInputs() == 3) { // has bias
+     beta_naive = 1.f;
auto inC = op->getInputs(2);
auto out = op->getOutput();
SmallArray inputShape, outputShape;
@@ -94,16 +92,38 @@ class matmulCublas : public Kernel {
(dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: n * k;
- stat = cublasGemmStridedBatchedEx(
-     context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
-     cuDataType, ldb, strideB, inAData, cuDataType, lda, strideA,
-     &beta, outData, cuDataType, ldc, m * n, b, cuDataType,
-     (cublasGemmAlgo_t)record->algo);
+ if (dataType == DataType::Float16) {
+     half alpha_half = static_cast<half>(alpha_naive);
+     half beta_half = static_cast<half>(beta_naive);
+     stat = cublasGemmStridedBatchedEx(
+         context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
+         inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
+         strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
+         cuDataType, (cublasGemmAlgo_t)record->algo);
+ } else {
+     stat = cublasGemmStridedBatchedEx(
+         context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
+         inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
+         strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
+         cuDataType, (cublasGemmAlgo_t)record->algo);
+ }
} else {
- stat = cublasGemmEx(
-     context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
-     cuDataType, ldb, inAData, cuDataType, lda, &beta, outData,
-     cuDataType, ldc, cuDataType, (cublasGemmAlgo_t)record->algo);
+ if (dataType == DataType::Float16) {
+     half alpha_half = static_cast<half>(alpha_naive);
+     half beta_half = static_cast<half>(beta_naive);
+     stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
+                         &alpha_half, inBData, cuDataType, ldb,
+                         inAData, cuDataType, lda, &beta_half,
+                         outData, cuDataType, ldc, cuDataType,
+                         (cublasGemmAlgo_t)record->algo);
+ } else {
+     stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
+                         &alpha_naive, inBData, cuDataType, ldb,
+                         inAData, cuDataType, lda, &beta_naive,
+                         outData, cuDataType, ldc, cuDataType,
+                         (cublasGemmAlgo_t)record->algo);
+ }
}
// if (stat != CUBLAS_STATUS_SUCCESS)
// cout << cublasGetErrorString(stat);
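
Why the cast is needed: for cublasGemmEx and cublasGemmStridedBatchedEx, the type that cuBLAS reads through the alpha and beta pointers is dictated by the compute type argument, not by the caller's variable declarations. This kernel passes cuDataType as the compute type, so for fp16 tensors the compute type is CUDA_R_16F and the scalars must be half; before this fix, passing &alpha (a float*) made cuBLAS reinterpret the float's bytes as a half, producing wrong scaling. Below is a minimal standalone sketch of the corrected call pattern; it is not code from this repository, and the buffer names, shapes, and unchecked cudaMalloc calls are illustrative only.

#include <cassert>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
    const int m = 2, n = 3, k = 4;
    cublasHandle_t handle;
    cublasStatus_t stat = cublasCreate(&handle);
    assert(stat == CUBLAS_STATUS_SUCCESS);

    // Illustrative fp16 device buffers; a real test would fill A and B
    // and verify C on the host.
    half *a, *b, *c;
    cudaMalloc(&a, m * k * sizeof(half));
    cudaMalloc(&b, k * n * sizeof(half));
    cudaMalloc(&c, m * n * sizeof(half));

    // With compute type CUDA_R_16F, cuBLAS dereferences alpha/beta as
    // half. Taking the address of a float here (the pre-fix behavior)
    // would make it read the float's first two bytes as a half value.
    half alpha = __float2half(1.0f), beta = __float2half(0.0f);

    // Same row-major trick as the kernel above: compute the column-major
    // product C^T = B^T * A^T, which is the row-major C = A * B.
    stat = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
                        b, CUDA_R_16F, n, a, CUDA_R_16F, k, &beta,
                        c, CUDA_R_16F, n,
                        CUDA_R_16F, // compute type fixes alpha/beta's type
                        CUBLAS_GEMM_DEFAULT);
    assert(stat == CUBLAS_STATUS_SUCCESS);

    cudaFree(a); cudaFree(b); cudaFree(c);
    cublasDestroy(handle);
    return 0;
}

The same scalar-type rule applies to the strided-batched variant, which is why the commit branches on DataType::Float16 at both call sites rather than only the batched one.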