forked from p04798526/LLaMA-Factory-Mirror
fix oom issues in export
This commit is contained in:
parent
7134fb02bb
commit
67ebc7b388
|
@ -145,7 +145,7 @@ class ModelArguments:
|
|||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_device: str = field(
|
||||
export_device: Literal["cpu", "cuda"] = field(
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
|
||||
)
|
||||
|
|
|
@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
if model_args.export_dir is not None:
|
||||
model_args.device_map = {"": torch.device(model_args.export_device)}
|
||||
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
||||
model_args.device_map = {"": torch.device("cpu")}
|
||||
else:
|
||||
model_args.device_map = "auto"
|
||||
|
||||
|
|
Loading…
Reference in New Issue