hiyouga 2024-07-10 11:32:36 +08:00
parent fb0c400116
commit 51942acee8
2 changed files with 11 additions and 6 deletions

@@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

-    if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
+    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
         raise ValueError("Cannot merge adapters to a quantized model.")

     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

-    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
-        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
+        setattr(model.config, "torch_dtype", torch.float16)
+    else:
+        if model_args.infer_dtype == "auto":
+            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+        else:
+            output_dtype = getattr(torch, model_args.infer_dtype)
+
         setattr(model.config, "torch_dtype", output_dtype)
         model = model.to(output_dtype)
-    else:
-        setattr(model.config, "torch_dtype", torch.float16)
+        logger.info("Convert model dtype to: {}.".format(output_dtype))

     model.save_pretrained(
         save_directory=model_args.export_dir,
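
Read as a whole, the new branch pins quantized checkpoints to float16 (their weights cannot be re-cast) and otherwise lets infer_dtype choose the export dtype, with "auto" deferring to the dtype recorded in the model config. A minimal sketch of that selection in isolation; the resolve_export_dtype helper and its argument shapes are hypothetical stand-ins for illustration, not code from this repository:

import torch

def resolve_export_dtype(quantization_method, config_dtype, infer_dtype: str) -> torch.dtype:
    # Quantized models stay at float16; a quantized model's dtype cannot be converted.
    if quantization_method is not None:
        return torch.float16
    # "auto" falls back to the dtype recorded in the model config,
    # defaulting to float16 when the config carries none.
    if infer_dtype == "auto":
        return config_dtype if config_dtype is not None else torch.float16
    # Otherwise look the name up on the torch module, e.g. "bfloat16" -> torch.bfloat16.
    return getattr(torch, infer_dtype)

assert resolve_export_dtype("gptq", torch.bfloat16, "auto") is torch.float16
assert resolve_export_dtype(None, torch.bfloat16, "auto") is torch.bfloat16
assert resolve_export_dtype(None, None, "float32") is torch.float32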

@@ -48,7 +48,7 @@ def save_model(
     template: str,
     visual_inputs: bool,
     export_size: int,
-    export_quantization_bit: int,
+    export_quantization_bit: str,
     export_quantization_dataset: str,
     export_device: str,
     export_legacy_format: bool,
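
The widened annotation matches what the web UI actually passes: a dropdown component yields its selection as a string such as "8" rather than an int. A hedged sketch of how such a value could be normalized before use; parse_quantization_bit and the "none" sentinel are assumptions for illustration, not the repository's exact handling:

from typing import Optional

def parse_quantization_bit(export_quantization_bit: str) -> Optional[int]:
    # Dropdown values arrive as strings; any non-numeric choice disables quantization.
    return int(export_quantization_bit) if export_quantization_bit.isdigit() else None

assert parse_quantization_bit("8") == 8
assert parse_quantization_bit("none") is None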