From 51942acee84cdb20002f8fdccf6be8c7fe9bd0d3 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 10 Jul 2024 11:32:36 +0800 Subject: [PATCH] fix #4731 --- src/llamafactory/train/tuner.py | 15 ++++++++++----- src/llamafactory/webui/components/export.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 99f2b660..cb55900f 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: get_template_and_fix_tokenizer(tokenizer, data_args.template) model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab - if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None: + if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None: raise ValueError("Cannot merge adapters to a quantized model.") if not isinstance(model, PreTrainedModel): raise ValueError("The model is not a `PreTrainedModel`, export aborted.") - if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model - output_dtype = getattr(model.config, "torch_dtype", torch.float16) + if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type + setattr(model.config, "torch_dtype", torch.float16) + else: + if model_args.infer_dtype == "auto": + output_dtype = getattr(model.config, "torch_dtype", torch.float16) + else: + output_dtype = getattr(torch, model_args.infer_dtype) + setattr(model.config, "torch_dtype", output_dtype) model = model.to(output_dtype) - else: - setattr(model.config, "torch_dtype", torch.float16) + logger.info("Convert model dtype to: {}.".format(output_dtype)) model.save_pretrained( save_directory=model_args.export_dir, diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 0a938f02..86fad2aa 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -48,7 +48,7 @@ def save_model( template: str, visual_inputs: bool, export_size: int, - export_quantization_bit: int, + export_quantization_bit: str, export_quantization_dataset: str, export_device: str, export_legacy_format: bool,