diff --git a/README.md b/README.md
index b3544d8d..b82ae74d 100644
--- a/README.md
+++ b/README.md
@@ -369,8 +369,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API Demo
diff --git a/README_zh.md b/README_zh.md
index d0584258..e33bfc23 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -368,8 +368,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API 服务
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index e5ebcb78..46d89cbf 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False
 
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()
 
     model_args.model_max_length = data_args.cutoff_len
 
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()
 
     return model_args, data_args, finetuning_args, generating_args