forked from jiuyuan/InfiniTensor
fix: fix matmul fp16
This commit is contained in:
parent a91ed84354
commit 935b465cf2
@@ -49,14 +49,12 @@ class matmulCublas : public Kernel {
         auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
         const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
                   ldc = n;
-        float alpha = 1.f, beta = 0.f;
+        float alpha_naive = 1.f, beta_naive = 0.f;
         auto dataType = op->getDType();
         auto cuDataType = cublasDataTypeConvert(dataType);
         IT_ASSERT(cuDataType != CUDA_R_8I, "matmul don't support int8 dtype.");
-        if (op->numInputs() == 2) { // no bias
-            beta = 0.f;
-        } else { // broadcast bias to output
-            beta = 1.f;
+        if (op->numInputs() == 3) { // have bias
+            beta_naive = 1.f;
             auto inC = op->getInputs(2);
             auto out = op->getOutput();
             SmallArray inputShape, outputShape;
@@ -94,16 +92,38 @@ class matmulCublas : public Kernel {
                  (dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
                     ? 0 // Broadcast the batch dimension if batch size is 1
                     : n * k;
-            stat = cublasGemmStridedBatchedEx(
-                context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
-                cuDataType, ldb, strideB, inAData, cuDataType, lda, strideA,
-                &beta, outData, cuDataType, ldc, m * n, b, cuDataType,
-                (cublasGemmAlgo_t)record->algo);
+            if (dataType == DataType::Float16) {
+                half alpha_half = static_cast<half>(alpha_naive);
+                half beta_half = static_cast<half>(beta_naive);
+                stat = cublasGemmStridedBatchedEx(
+                    context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
+                    inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
+                    strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
+                    cuDataType, (cublasGemmAlgo_t)record->algo);
+
+            } else {
+                stat = cublasGemmStridedBatchedEx(
+                    context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
+                    inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
+                    strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
+                    cuDataType, (cublasGemmAlgo_t)record->algo);
+            }
         } else {
-            stat = cublasGemmEx(
-                context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
-                cuDataType, ldb, inAData, cuDataType, lda, &beta, outData,
-                cuDataType, ldc, cuDataType, (cublasGemmAlgo_t)record->algo);
+            if (dataType == DataType::Float16) {
+                half alpha_half = static_cast<half>(alpha_naive);
+                half beta_half = static_cast<half>(beta_naive);
+                stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
+                                    &alpha_half, inBData, cuDataType, ldb,
+                                    inAData, cuDataType, lda, &beta_half,
+                                    outData, cuDataType, ldc, cuDataType,
+                                    (cublasGemmAlgo_t)record->algo);
+            } else {
+                stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
+                                    &alpha_naive, inBData, cuDataType, ldb,
+                                    inAData, cuDataType, lda, &beta_naive,
+                                    outData, cuDataType, ldc, cuDataType,
+                                    (cublasGemmAlgo_t)record->algo);
+            }
         }
         // if (stat != CUBLAS_STATUS_SUCCESS)
         //     cout << cublasGetErrorString(stat);
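Context for the change: when cublasGemmEx or cublasGemmStridedBatchedEx runs with a CUDA_R_16F compute type, cuBLAS dereferences the alpha and beta pointers as half, not float, so the previous code's float scalars were reinterpreted bit-wise and the fp16 GEMM was scaled wrongly. The commit therefore branches on DataType::Float16 and passes half-typed scalars. Below is a minimal standalone sketch of that requirement, not taken from the repository: the 2x2 sizes and variable names are illustrative, error checking is omitted, and it uses the backward-compatible cublasGemmEx overload that accepts a cudaDataType compute type, as the patched kernel does. Build with nvcc and link with -lcublas.

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int n = 2; // tiny 2x2 GEMM, sizes chosen only for illustration
    half *A, *B, *C;
    cudaMallocManaged(&A, n * n * sizeof(half));
    cudaMallocManaged(&B, n * n * sizeof(half));
    cudaMallocManaged(&C, n * n * sizeof(half));
    for (int i = 0; i < n * n; ++i) {
        A[i] = __float2half(1.f);
        B[i] = __float2half(1.f);
        C[i] = __float2half(0.f);
    }

    cublasHandle_t handle;
    cublasCreate(&handle);

    // With a half compute type, cuBLAS reads alpha/beta as half. Passing a
    // float* here (the pre-fix behavior) would reinterpret the bytes of
    // 1.0f as a half and silently corrupt the scaling factors.
    half alpha = __float2half(1.f), beta = __float2half(0.f);
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, A,
                 CUDA_R_16F, n, B, CUDA_R_16F, n, &beta, C, CUDA_R_16F, n,
                 CUDA_R_16F, CUBLAS_GEMM_DEFAULT);
    cudaDeviceSynchronize();

    printf("C[0] = %f\n", __half2float(C[0])); // all-ones 2x2 product: 2.0

    cublasDestroy(handle);
    cudaFree(A);
    cudaFree(B);
    cudaFree(C);
    return 0;
}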