LLaMA-Factory-310P3/mindie/examples/convert/convert_utils.py

27 lines
1.2 KiB
Python

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import json
import os.path
import shutil
from atb_llm.utils.file_utils import safe_open
def copy_tokenizer_files(model_dir, dest_dir):
    """Copy tokenizer-related files from ``model_dir`` into ``dest_dir``.

    A directory entry is copied when its name contains one of the markers
    ``tokenizer``, ``tokenization``, ``special_tokens_map`` or
    ``special_token_map``. Only regular files are copied; sub-directories
    are skipped. Existing files in ``dest_dir`` with the same name are
    overwritten (``shutil.copyfile`` semantics).

    Args:
        model_dir: directory to scan (non-recursively) for tokenizer files.
        dest_dir: destination directory; created (with parents) if missing.
    """
    os.makedirs(dest_dir, exist_ok=True)
    # 'special_tokens_map' matches the standard file name special_tokens_map.json,
    # which the old 'special_token_map' marker never matched. The old marker is
    # kept so every file selected before is still selected.
    markers = ('tokenizer', 'tokenization', 'special_tokens_map', 'special_token_map')
    for filename in os.listdir(model_dir):
        if not any(marker in filename for marker in markers):
            continue
        src_filepath = os.path.join(model_dir, filename)
        # os.listdir also lists sub-directories; copyfile only handles files.
        if not os.path.isfile(src_filepath):
            continue
        shutil.copyfile(src_filepath, os.path.join(dest_dir, filename))
def modify_config(model_dir, dest_dir, torch_dtype, quantize_type, kv_quant_type=False):
    """Write an updated copy of ``model_dir/config.json`` to ``dest_dir/config.json``.

    The source JSON is loaded, the keys below are set, and the result is
    written with 4-space indentation.

    Args:
        model_dir: directory containing the source ``config.json``.
        dest_dir: directory receiving the rewritten ``config.json``; created
            (with parents) if it does not exist yet.
        torch_dtype: value stored under ``"torch_dtype"``. Only the last
            dot-separated component of ``str(torch_dtype)`` is kept. A torch
            dtype such as ``torch.float16`` (str ``"torch.float16"``) and the
            plain string ``"float16"`` both yield ``"float16"``.
        quantize_type: value stored under ``"quantize"``.
        kv_quant_type: when truthy, ``"kv_quant"`` is set to ``"C8"``.

    Note:
        The destination is opened via ``safe_open(..., is_exist_ok=False)``.
        Presumably this refuses to overwrite an existing destination config.
        Confirm against ``atb_llm.utils.file_utils.safe_open``.
    """
    src_config_filepath = os.path.join(model_dir, 'config.json')
    with open(src_config_filepath, 'r', encoding='utf-8') as fr:
        data = json.load(fr)

    # Take the last component instead of index [1]: same result for "torch.xxx",
    # but no IndexError when a dtype name without a dot is passed.
    data['torch_dtype'] = str(torch_dtype).rsplit(".", 1)[-1]
    data['quantize'] = quantize_type
    if kv_quant_type:
        data['kv_quant'] = "C8"  # KV cache quantization currently supports only the C8 scheme

    os.makedirs(dest_dir, exist_ok=True)
    dest_config_filepath = os.path.join(dest_dir, 'config.json')
    with safe_open(dest_config_filepath, 'w', encoding='utf-8', is_exist_ok=False) as fw:
        json.dump(data, fw, indent=4)