Merge pull request #2266 from yhyu13/fix_export_model_dtype

Remove manually set use_cache; torch_dtype is not str, save model as b…
This commit is contained in:
hoshi-hiyouga 2024-01-21 12:40:39 +08:00 committed by GitHub
commit ea6db72631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 4 deletions

View File

@ -56,12 +56,11 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
setattr(model.config, "use_cache", True)
if getattr(model.config, "torch_dtype", None) == torch.bfloat16:
model = model.to(torch.bfloat16).to("cpu")
if hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")
setattr(model.config, "torch_dtype", "float16")
setattr(model.config, "torch_dtype", torch.float16)
model.save_pretrained(
save_directory=model_args.export_dir,