fix(matmul): fix the data type conversion function for matmul.

This commit is contained in:
kilinchange 2024-08-07 14:28:23 +08:00
parent 068b7e51d6
commit 3c2a991db9
1 changed file with 5 additions and 5 deletions

View File

@@ -36,9 +36,9 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) { cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
if (cuDataType == CUDA_R_16F) { if (cuDataType == CUDA_R_16F) {
return CUBLAS_COMPUTE_32F_FAST_16F; return CUBLAS_COMPUTE_16F;
} else if (cuDataType == CUDA_R_16BF) { } else if (cuDataType == CUDA_R_16BF) {
return CUBLAS_COMPUTE_32F_FAST_16BF; return CUBLAS_COMPUTE_16F;
} else if (cuDataType == CUDA_R_32F) { } else if (cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F; return CUBLAS_COMPUTE_32F;
} else if (cuDataType == CUDA_R_64F) { } else if (cuDataType == CUDA_R_64F) {
@@ -50,11 +50,11 @@ cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
cublasComputeType_t getCuComputeType(std::string computeTypeStr, cublasComputeType_t getCuComputeType(std::string computeTypeStr,
cudaDataType_t cuDataType) { cudaDataType_t cuDataType) {
if (computeTypeStr == "tf32") { if (computeTypeStr == "tf32" && cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F_FAST_TF32; return CUBLAS_COMPUTE_32F_FAST_TF32;
} else if (computeTypeStr == "bf16") { } else if (computeTypeStr == "bf16" && cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F_FAST_16BF; return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (computeTypeStr == "fp16") { } else if (computeTypeStr == "fp16" && cuDataType == CUDA_R_32F) {
return CUBLAS_COMPUTE_32F_FAST_16F; return CUBLAS_COMPUTE_32F_FAST_16F;
} else if (computeTypeStr == "default") { } else if (computeTypeStr == "default") {
return cuDataType2ComputeType(cuDataType); return cuDataType2ComputeType(cuDataType);