Remove manually set use_cache; torch_dtype is not a str, so saving the model as bfloat16 used to fail

yhyu13 2024-01-21 11:12:15 +08:00
parent a0d59aa4ec
commit 9cdbd3bfc8
1 changed file with 4 additions and 5 deletions


@@ -56,12 +56,11 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
-    setattr(model.config, "use_cache", True)
-    if getattr(model.config, "torch_dtype", None) == "bfloat16":
-        model = model.to(torch.bfloat16).to("cpu")
+    if hasattr(model.config, "torch_dtype"):
+        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
     else:
-        model = model.to(torch.float16).to("cpu")
-        setattr(model.config, "torch_dtype", "float16")
+        model = model.to(torch.float32).to("cpu")
+        setattr(model.config, "torch_dtype", "float32")
     model.save_pretrained(
         save_directory=model_args.export_dir,
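
For readers outside the diff context, here is a minimal, self-contained sketch of the dtype handling after this commit. The tiny GPT-2 model and the export_dir path are illustrative assumptions; in the actual code the model and model_args.export_dir come from the surrounding export_model function.

# Standalone sketch of the post-commit export logic (assumption: a tiny
# GPT-2 model and a local path stand in for the real model and model_args).
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Build a small bfloat16 model locally so the example runs without downloads.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, torch_dtype=torch.bfloat16)
model = GPT2LMHeadModel(config).to(torch.bfloat16)

export_dir = "./exported_model"  # hypothetical path

# config.torch_dtype holds a torch.dtype such as torch.bfloat16, not the
# string "bfloat16", so the removed `== "bfloat16"` comparison never matched
# and bfloat16 models fell through to the float16 branch.
if hasattr(model.config, "torch_dtype"):
    model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
    model = model.to(torch.float32).to("cpu")
    setattr(model.config, "torch_dtype", "float32")

model.save_pretrained(save_directory=export_dir)

Dropping the manual setattr(model.config, "use_cache", True) also means the exported config keeps whatever use_cache value the loaded model already carries, instead of forcing it on.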