diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 99f2b660..cb55900f 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: get_template_and_fix_tokenizer(tokenizer, data_args.template) model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab - if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None: + if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None: raise ValueError("Cannot merge adapters to a quantized model.") if not isinstance(model, PreTrainedModel): raise ValueError("The model is not a `PreTrainedModel`, export aborted.") - if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model - output_dtype = getattr(model.config, "torch_dtype", torch.float16) + if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type + setattr(model.config, "torch_dtype", torch.float16) + else: + if model_args.infer_dtype == "auto": + output_dtype = getattr(model.config, "torch_dtype", torch.float16) + else: + output_dtype = getattr(torch, model_args.infer_dtype) + setattr(model.config, "torch_dtype", output_dtype) model = model.to(output_dtype) - else: - setattr(model.config, "torch_dtype", torch.float16) + logger.info("Convert model dtype to: {}.".format(output_dtype)) model.save_pretrained( save_directory=model_args.export_dir, diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 0a938f02..86fad2aa 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -48,7 +48,7 @@ def save_model( template: str, visual_inputs: bool, export_size: int, - export_quantization_bit: int, + export_quantization_bit: str, export_quantization_dataset: str, export_device: str, export_legacy_format: bool,