fix #4731
This commit is contained in:
parent fb0c400116
commit 51942acee8
@@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

-    if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
+    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
         raise ValueError("Cannot merge adapters to a quantized model.")

     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

-    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
-        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
+        setattr(model.config, "torch_dtype", torch.float16)
+    else:
+        if model_args.infer_dtype == "auto":
+            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+        else:
+            output_dtype = getattr(torch, model_args.infer_dtype)
+
         setattr(model.config, "torch_dtype", output_dtype)
         model = model.to(output_dtype)
-    else:
-        setattr(model.config, "torch_dtype", torch.float16)
+        logger.info("Convert model dtype to: {}.".format(output_dtype))

     model.save_pretrained(
         save_directory=model_args.export_dir,
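For reference, the dtype selection added above can be checked in isolation. The following is a minimal sketch, not code from this commit: it copies the new branch into a hypothetical helper (resolve_output_dtype) and shows that infer_dtype="auto" keeps the dtype recorded in the model config, an explicit name such as "float32" is resolved through getattr(torch, ...), and a quantized model always falls back to float16.

# Minimal sketch (hypothetical helper, not part of the commit) of the dtype selection above.
import torch


def resolve_output_dtype(config_dtype: torch.dtype, infer_dtype: str, quantized: bool) -> torch.dtype:
    if quantized:  # quantized model adopts float16 type
        return torch.float16
    if infer_dtype == "auto":  # keep the dtype already recorded in model.config
        return config_dtype
    return getattr(torch, infer_dtype)  # e.g. "float32" -> torch.float32


print(resolve_output_dtype(torch.bfloat16, "auto", quantized=False))     # torch.bfloat16
print(resolve_output_dtype(torch.bfloat16, "float32", quantized=False))  # torch.float32
print(resolve_output_dtype(torch.bfloat16, "float32", quantized=True))   # torch.float16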
@@ -48,7 +48,7 @@ def save_model(
     template: str,
     visual_inputs: bool,
     export_size: int,
-    export_quantization_bit: int,
+    export_quantization_bit: str,
     export_quantization_dataset: str,
     export_device: str,
     export_legacy_format: bool,
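The annotation change above suggests the quantization bit reaches save_model from the UI as a string choice rather than an int. As a hypothetical illustration only (the helper name and the "none" sentinel are assumptions, not taken from this commit), such a value could be normalized before building the export arguments:

from typing import Optional


def parse_quantization_bit(export_quantization_bit: str) -> Optional[int]:
    # Hypothetical helper: map the UI's string choice to an integer bit width,
    # treating a "none" selection as "do not quantize".
    if export_quantization_bit == "none":
        return None
    return int(export_quantization_bit)


print(parse_quantization_bit("8"))     # 8
print(parse_quantization_bit("none"))  # None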