41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
|
import argparse
|
|
|
|
from atb_llm.utils.convert import convert_files
|
|
from atb_llm.utils.hub import weight_files
|
|
from atb_llm.utils.log import logger
|
|
|
|
|
|
def parse_arguments():
    """Parse the command-line options of this conversion script.

    Returns:
        argparse.Namespace: parsed options; ``model_path`` holds the value
        given for ``--model_path``.

    Raises:
        SystemExit: when ``--model_path`` is missing or unrecognised
            arguments are passed (standard ``argparse`` behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Convert PyTorch .bin weight files to safetensors.")
    # Without required=True a missing option silently produces
    # model_path=None, which would then be handed to the conversion
    # functions instead of producing a clear usage error here.
    parser.add_argument('--model_path', required=True, help="model and tokenizer path")
    return parser.parse_args()
|
|
|
|
|
|
def convert_bin2st(model_path):
    """Convert the ``.bin`` weight files found via ``weight_files`` to safetensors.

    Each ``<stem>.bin`` file is converted to ``<stem>.safetensors`` in the same
    directory. If the stem starts with the literal prefix ``pytorch_``, that
    prefix is dropped, e.g. ``pytorch_model-00001-of-00002.bin`` ->
    ``model-00001-of-00002.safetensors``. Other stems are kept unchanged.

    Args:
        model_path: location passed to ``weight_files`` to look up the
            ``.bin`` files.
    """
    prefix = "pytorch_"
    local_pt_files = weight_files(model_path, revision=None, extension=".bin")

    # Remove the literal prefix only. The previous str.lstrip("pytorch_")
    # treated its argument as a *character set* and stripped any leading run
    # of {p, y, t, o, r, c, h, _}, mangling stems such as "optimizer" into
    # "imizer".
    local_st_files = []
    for pt_file in local_pt_files:
        stem = pt_file.stem
        if stem.startswith(prefix):
            stem = stem[len(prefix):]
        local_st_files.append(pt_file.parent / f"{stem}.safetensors")

    convert_files(local_pt_files, local_st_files, discard_names=[])
    # Re-query with weight_files' default extension. NOTE(review): presumably
    # this verifies the newly written safetensors files are discoverable;
    # confirm against atb_llm.utils.hub.weight_files.
    _ = weight_files(model_path)
|
|
|
|
|
|
def convert_bin2st_from_pretrained(model_path):
    """Fallback conversion: load the model with transformers and re-save it.

    The checkpoint at ``model_path`` is loaded through
    ``AutoModelForCausalLM.from_pretrained``. It is then written back to the
    same ``model_path`` with ``save_pretrained(..., safe_serialization=True)``,
    which stores the weights in safetensors format.

    Args:
        model_path: model directory or identifier accepted by
            ``from_pretrained``; also used as the save destination.
    """
    # Imported lazily so transformers is only required on this fallback path.
    from transformers import AutoModelForCausalLM

    load_kwargs = dict(
        pretrained_model_name_or_path=model_path,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    loaded_model = AutoModelForCausalLM.from_pretrained(**load_kwargs)
    loaded_model.save_pretrained(model_path, safe_serialization=True)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: try the direct per-file conversion first and fall
    # back to loading the whole model when it raises RuntimeError.
    args = parse_arguments()

    try:
        convert_bin2st(args.model_path)
    except RuntimeError as err:
        # Keep the original failure in the log; previously the cause of the
        # fallback was discarded whenever the fallback itself succeeded.
        logger.warning(
            f'convert weights failed with torch.load method, need model loaded to convert: {err}')
        convert_bin2st_from_pretrained(args.model_path)