27 lines
1.2 KiB
Python
27 lines
1.2 KiB
Python
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
|
import json
|
|
import os.path
|
|
import shutil
|
|
from atb_llm.utils.file_utils import safe_open
|
|
|
|
|
|
def copy_tokenizer_files(model_dir, dest_dir):
    """Copy tokenizer-related files from ``model_dir`` into ``dest_dir``.

    A directory entry is copied when its name contains one of the tokenizer
    markers listed below. ``dest_dir`` is created if it does not exist, and
    an existing file of the same name in ``dest_dir`` is overwritten.

    Args:
        model_dir: Directory holding the source model files.
        dest_dir: Directory that receives the copied files.

    Raises:
        OSError: If ``model_dir`` cannot be listed or a file cannot be copied.
    """
    # 'special_token_map' is kept for backward compatibility; the file written
    # by Hugging Face tokenizers is 'special_tokens_map.json', which the legacy
    # marker alone does not match.
    markers = ('tokenizer', 'tokenization', 'special_token_map', 'special_tokens_map')

    os.makedirs(dest_dir, exist_ok=True)
    for filename in os.listdir(model_dir):
        if not any(marker in filename for marker in markers):
            continue
        src_filepath = os.path.join(model_dir, filename)
        # shutil.copyfile only handles files; skip directories whose name
        # happens to contain a marker instead of failing on them.
        if not os.path.isfile(src_filepath):
            continue
        dest_filepath = os.path.join(dest_dir, filename)
        shutil.copyfile(src_filepath, dest_filepath)
|
|
|
|
|
|
def modify_config(model_dir, dest_dir, torch_dtype, quantize_type, kv_quant_type=False):
    """Write an updated copy of the model's ``config.json`` into ``dest_dir``.

    Loads ``<model_dir>/config.json``, overrides the ``torch_dtype`` and
    ``quantize`` entries, optionally adds ``kv_quant``, and writes the result
    to ``<dest_dir>/config.json`` with 4-space indentation.

    Args:
        model_dir: Directory containing the source ``config.json``.
        dest_dir: Directory that receives the modified ``config.json``.
        torch_dtype: Dtype whose string form is either dotted (for example
            ``torch.float16``) or a bare name (``float16``); only the part
            after the last dot is stored.
        quantize_type: Value stored under the ``quantize`` key.
        kv_quant_type: When truthy, ``kv_quant`` is set to ``"C8"``.

    Raises:
        OSError: If the source config cannot be read.
        json.JSONDecodeError: If the source config is not valid JSON.
    """
    src_config_filepath = os.path.join(model_dir, 'config.json')
    with open(src_config_filepath, 'r', encoding='utf-8') as fr:
        data = json.load(fr)

    # Keep only the name after the last dot so that both "torch.float16" and a
    # bare "float16" yield "float16" (indexing [1] after split fails on the latter).
    data['torch_dtype'] = str(torch_dtype).rsplit(".", 1)[-1]
    data['quantize'] = quantize_type
    if kv_quant_type:
        data['kv_quant'] = "C8"  # Currently the KV cache only supports the C8 quantization type.

    dest_config_filepath = os.path.join(dest_dir, 'config.json')
    # NOTE(review): is_exist_ok=False presumably makes safe_open refuse to
    # overwrite an existing file -- confirm against atb_llm.utils.file_utils.
    with safe_open(dest_config_filepath, 'w', encoding='utf-8', is_exist_ok=False) as fw:
        json.dump(data, fw, indent=4)