fix(matmul): fix the data type conversion function for matmul.
This commit is contained in:
parent
068b7e51d6
commit
3c2a991db9
|
@ -36,9 +36,9 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
|||
|
||||
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
|
||||
if (cuDataType == CUDA_R_16F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
} else if (cuDataType == CUDA_R_16BF) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
} else if (cuDataType == CUDA_R_32F) {
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
} else if (cuDataType == CUDA_R_64F) {
|
||||
|
@ -50,11 +50,11 @@ cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
|
|||
|
||||
cublasComputeType_t getCuComputeType(std::string computeTypeStr,
|
||||
cudaDataType_t cuDataType) {
|
||||
if (computeTypeStr == "tf32") {
|
||||
if (computeTypeStr == "tf32" && cuDataType == CUDA_R_32F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
} else if (computeTypeStr == "bf16") {
|
||||
} else if (computeTypeStr == "bf16" && cuDataType == CUDA_R_32F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
} else if (computeTypeStr == "fp16") {
|
||||
} else if (computeTypeStr == "fp16" && cuDataType == CUDA_R_32F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
} else if (computeTypeStr == "default") {
|
||||
return cuDataType2ComputeType(cuDataType);
|
||||
|
|
Loading…
Reference in New Issue