fix oom issues in export
parent 7134fb02bb
commit 67ebc7b388
@@ -145,7 +145,7 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
-    export_device: str = field(
+    export_device: Literal["cpu", "cuda"] = field(
         default="cpu",
         metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
     )
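For context, a minimal sketch of the narrowed annotation is shown below. The dataclass layout mirrors the diff, but the first field name (export_size) is a guess based on the visible default and help text, and the note about argument parsing is an assumption: a recent transformers.HfArgumentParser typically maps a Literal annotation to argparse choices, so values other than "cpu" or "cuda" would be rejected at parse time rather than failing later during export.

# Minimal sketch, not the repository's actual file; export_size is a
# hypothetical name for the shard-size field shown in the diff above.
from dataclasses import dataclass, field
from typing import Literal


@dataclass
class ModelArguments:
    export_size: int = field(
        default=1,
        metadata={"help": "The file shard size (in GB) of the exported model."},
    )
    export_device: Literal["cpu", "cuda"] = field(
        default="cpu",
        metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
    )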
@@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

-    if model_args.export_dir is not None:
-        model_args.device_map = {"": torch.device(model_args.export_device)}
+    if model_args.export_dir is not None and model_args.export_device == "cpu":
+        model_args.device_map = {"": torch.device("cpu")}
     else:
         model_args.device_map = "auto"
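To make the effect of the changed branch explicit, here is a small sketch that isolates it into a standalone helper. select_device_map is a hypothetical name used only for illustration, not part of the codebase; the logic and the resulting device_map values follow the diff above, and the OOM reasoning is inferred from the commit title: the old code pinned the whole model to a single device named by export_device, so a CUDA export loaded everything onto one GPU, while the new code pins to CPU only for an explicit CPU export and otherwise lets "auto" placement shard the weights across available devices.

# Hypothetical helper (not in the repository) that isolates the branch
# from the diff above, purely to show the resulting device_map values.
import torch


def select_device_map(export_dir, export_device):
    # Pin everything to CPU only for an explicit CPU export; for "cuda",
    # fall through to "auto" so the weights can be spread across available
    # devices instead of being forced onto a single GPU.
    if export_dir is not None and export_device == "cpu":
        return {"": torch.device("cpu")}
    return "auto"


assert select_device_map("exported_model", "cpu") == {"": torch.device("cpu")}
assert select_device_map("exported_model", "cuda") == "auto"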