From 42859f073434eab0928940e8a9c52f275a2fc93a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 16 Jan 2024 23:59:42 +0800 Subject: [PATCH] support export push_to_hub #2183 --- src/llmtuner/hparams/model_args.py | 4 ++++ src/llmtuner/train/tuner.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 36ff1e3f..356de716 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -106,6 +106,10 @@ class ModelArguments: default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."} ) + export_hub_model_id: Optional[str] = field( + default=None, + metadata={"help": "The name of the repository when pushing the model to the Hugging Face hub."} + ) def __post_init__(self): self.compute_dtype = None diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index c5100306..8705c98e 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -50,7 +50,7 @@ def export_model(args: Optional[Dict[str, Any]] = None): if not isinstance(model, PreTrainedModel): raise ValueError("The model is not a `PreTrainedModel`, export aborted.") - model.config.use_cache = True + setattr(model.config, "use_cache", True) if getattr(model.config, "torch_dtype", None) == "bfloat16": model = model.to(torch.bfloat16).to("cpu") else: @@ -62,11 +62,20 @@ def export_model(args: Optional[Dict[str, Any]] = None): max_shard_size="{}GB".format(model_args.export_size), safe_serialization=(not model_args.export_legacy_format) ) + if model_args.export_hub_model_id is not None: + model.push_to_hub( + model_args.export_hub_model_id, + token=model_args.hf_hub_token, + max_shard_size="{}GB".format(model_args.export_size), + safe_serialization=(not model_args.export_legacy_format) + ) try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" 
tokenizer.save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) except: logger.warning("Cannot save tokenizer, please copy the files manually.")