From 13851fb04524e3a599b6c07d749f7463b8f75319 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 11 May 2024 23:54:53 +0800 Subject: [PATCH] Update tuner.py --- src/llmtuner/train/tuner.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 11509c20..cf44aa8c 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -15,9 +15,11 @@ from .pt import run_pt from .rm import run_rm from .sft import run_sft + if TYPE_CHECKING: from transformers import TrainerCallback + logger = get_logger(__name__) @@ -51,8 +53,8 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: raise ValueError("Please merge adapters before quantizing the model.") tokenizer_module = load_tokenizer(model_args) - tokenizer = tokenizer_module['tokenizer'] - processor = tokenizer_module['processor'] + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] get_template_and_fix_tokenizer(tokenizer, data_args.template) model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab @@ -63,7 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: raise ValueError("The model is not a `PreTrainedModel`, export aborted.") if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model - output_dtype = getattr(model.config, "torch_dtype", torch.float16) + output_dtype = torch.float16 setattr(model.config, "torch_dtype", output_dtype) model = model.to(output_dtype) @@ -86,10 +88,12 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: tokenizer.save_pretrained(model_args.export_dir) if model_args.export_hub_model_id is not None: tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) + + if model_args.visual_inputs and processor is not None: + getattr(processor, "image_processor").save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + getattr(processor, "image_processor").push_to_hub( + model_args.export_hub_model_id, token=model_args.hf_hub_token + ) except Exception: logger.warning("Cannot save tokenizer, please copy the files manually.") - - if model_args.visual_inputs: - processor.image_processor.save_pretrained(model_args.export_dir) - if model_args.export_hub_model_id is not None: - processor.image_processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) \ No newline at end of file