diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 5885bb09..650d1c22 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -145,7 +145,7 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
-    export_device: str = field(
+    export_device: Literal["cpu", "cuda"] = field(
         default="cpu",
         metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
     )
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 20f9a003..6311297e 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
-    if model_args.export_dir is not None:
-        model_args.device_map = {"": torch.device(model_args.export_device)}
+    if model_args.export_dir is not None and model_args.export_device == "cpu":
+        model_args.device_map = {"": torch.device("cpu")}
     else:
         model_args.device_map = "auto"