fix oom issues in export

This commit is contained in:
hiyouga 2024-05-23 23:32:45 +08:00
parent 7134fb02bb
commit 67ebc7b388
2 changed files with 3 additions and 3 deletions

View File

@ -145,7 +145,7 @@ class ModelArguments:
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: str = field(
export_device: Literal["cpu", "cuda"] = field(
default="cpu",
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
)

View File

@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
else:
model_args.device_map = "auto"