LLaMA-Factory-310P3/mindie/examples/convert/convert_weights.py

41 lines
1.3 KiB
Python

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import argparse
from atb_llm.utils.convert import convert_files
from atb_llm.utils.hub import weight_files
from atb_llm.utils.log import logger
def parse_arguments():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: namespace whose ``model_path`` attribute holds the
        value given to ``--model_path``.

    Note:
        ``--model_path`` is mandatory. Without it the namespace would carry
        ``None``, which would then be handed to the converters; argparse now
        rejects the invocation up front with a usage message instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        required=True,
        help="model and tokenizer path",
    )
    return parser.parse_args()
def convert_bin2st(model_path):
    """Convert the ``.bin`` weight files found for *model_path* to safetensors.

    Each ``.bin`` file returned by ``weight_files`` gets a sibling
    ``.safetensors`` target in the same directory; a leading ``pytorch_`` in
    the file stem is dropped (``pytorch_model.bin`` -> ``model.safetensors``).
    The actual conversion is delegated to ``convert_files``.

    Args:
        model_path: model location forwarded to ``weight_files``.

    Note:
        Errors raised by ``weight_files`` / ``convert_files`` are not handled
        here and propagate to the caller.
    """
    prefix = "pytorch_"
    local_pt_files = weight_files(model_path, revision=None, extension=".bin")
    local_st_files = []
    for p in local_pt_files:
        stem = p.stem
        # str.lstrip() takes a *set of characters*, not a prefix, so
        # lstrip('pytorch_') also eats leading letters of unrelated names
        # (e.g. "consolidated" -> "nsolidated"). Remove the exact prefix only.
        if stem.startswith(prefix):
            stem = stem[len(prefix):]
        local_st_files.append(p.parent / f"{stem}.safetensors")
    convert_files(local_pt_files, local_st_files, discard_names=[])
    # Result intentionally discarded; presumably this re-resolves the weight
    # files to confirm the converted ones are found - TODO confirm.
    _ = weight_files(model_path)
def convert_bin2st_from_pretrained(model_path):
    """Re-save the model at *model_path* with safetensors serialization.

    The model is loaded through ``AutoModelForCausalLM.from_pretrained`` and
    written back to the same ``model_path`` with ``safe_serialization=True``.

    Args:
        model_path: location the model is loaded from and saved back to.
    """
    # Imported here rather than at module level, so transformers is only
    # needed when this function is actually called.
    from transformers import AutoModelForCausalLM

    load_options = {
        "pretrained_model_name_or_path": model_path,
        "low_cpu_mem_usage": True,
        "torch_dtype": "auto",
    }
    causal_lm = AutoModelForCausalLM.from_pretrained(**load_options)
    causal_lm.save_pretrained(model_path, safe_serialization=True)
if __name__ == '__main__':
    # Script entry point: try the direct .bin -> safetensors conversion first
    # and fall back to a full model load/save when it raises RuntimeError.
    cli_args = parse_arguments()
    target_path = cli_args.model_path
    try:
        convert_bin2st(target_path)
    except RuntimeError:
        logger.warning('convert weights failed with torch.load method, need model loaded to convert')
        convert_bin2st_from_pretrained(target_path)