# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
|
|
|
import os
|
|
import json
|
|
import shutil
|
|
import argparse
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
from atb_llm.utils.file_utils import safe_open
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of this quantization script.

    ``--model_path`` and ``--dataset_path`` are required; ``--save_path``
    defaults to ``./quant_weight_glm`` and ``--group_size`` to ``128``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Creating quant weights for ChatGLM2-6B or ChatGLM3-6B")

    # (flag, add_argument keywords) pairs, registered in declaration order.
    option_table = (
        ("--model_path", {"type": str, "required": True, "help": "The path to model float weights"}),
        ("--save_path", {"type": str, "default": "./quant_weight_glm", "help": "The path to save quant weights"}),
        ("--dataset_path", {"type": str, "required": True, "help": "The dataset path"}),
        ("--group_size", {"type": int, "default": 128, "help": "The group size for w4a16"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
|
|
def get_model_and_tokenizer(model_path):
    """Load the float32 model and its tokenizer from ``model_path``.

    Args:
        model_path: Passed as ``pretrained_model_name_or_path`` to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to the CPU
        and switched to eval mode.
    """
    # NOTE(review): trust_remote_code=True lets the loaders run Python code
    # shipped with the checkpoint -- only point this at trusted directories.
    common_kwargs = {
        "pretrained_model_name_or_path": model_path,
        "trust_remote_code": True,
    }
    tokenizer = AutoTokenizer.from_pretrained(**common_kwargs)

    loaded = AutoModelForCausalLM.from_pretrained(torch_dtype=torch.float32, **common_kwargs)
    model = loaded.cpu()
    model.eval()
    return model, tokenizer
|
|
|
|
|
|
def read_dataset(dataset_path):
    """Read a JSON Lines calibration file into a list of parsed records.

    Args:
        dataset_path: Path of a UTF-8 text file holding one JSON document
            per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    calib_set = []
    with safe_open(dataset_path, encoding='utf-8') as file:
        for line in file:
            # A blank or whitespace-only line (typically a trailing empty
            # line) is not a JSON document; json.loads would raise on it.
            if not line.strip():
                continue
            calib_set.append(json.loads(line))
    return calib_set
|
|
|
|
|
|
# Build the calibration dataset.
def get_calib_dataset(tokenizer, calib_list, device="cpu"):  # pass device="npu:0" to quantize on an NPU
    """Tokenize every calibration record and move its tensors to ``device``.

    Args:
        tokenizer: Callable invoked as ``tokenizer([text], return_tensors='pt')``;
            its result must expose a ``data`` mapping containing
            ``input_ids``, ``position_ids`` and ``attention_mask``.
        calib_list: Iterable of mappings, each with an ``inputs_pretokenized`` entry.
        device: Target passed to ``.to(...)`` for every tensor.

    Returns:
        A list with one ``[input_ids, position_ids, attention_mask]`` list
        per record, in input order.
    """
    # Order matters: each sample is emitted as [input_ids, position_ids, attention_mask].
    tensor_keys = ('input_ids', 'position_ids', 'attention_mask')

    calib_dataset = []
    for record in calib_list:
        encoded = tokenizer([record['inputs_pretokenized']], return_tensors='pt')
        sample = [encoded.data[key].to(device) for key in tensor_keys]
        calib_dataset.append(sample)
    return calib_dataset
|
|
|
|
|
|
def copy_config_files(fp16_path, quant_path, quant_type):
    """Copy config/tokenizer/modeling files to ``quant_path`` and tag the config.

    Every regular file in ``fp16_path`` whose name starts with ``config``,
    ``tokeniz`` or ``modeling_chatglm.py`` is copied (with metadata) into
    ``quant_path``. Afterwards ``quant_path/config.json`` is rewritten in
    place with its ``quantize`` entry set to ``quant_type``.

    Args:
        fp16_path: Directory holding the float model files.
        quant_path: Existing directory that receives the copies.
        quant_type: Value stored under the ``quantize`` key of ``config.json``.
    """
    wanted_prefixes = ("config", "tokeniz", "modeling_chatglm.py")
    for name in os.listdir(fp16_path):
        if not name.startswith(wanted_prefixes):
            continue
        source = os.path.join(fp16_path, name)
        # os.listdir also returns directories; shutil.copy2 only handles
        # files, so a sub-directory with a matching name is skipped.
        if not os.path.isfile(source):
            continue
        shutil.copy2(source, os.path.join(quant_path, name))

    config_path = os.path.join(quant_path, "config.json")
    with safe_open(config_path, 'r+', encoding='utf-8') as config_file:
        config = json.load(config_file)
        config['quantize'] = quant_type
        # Rewrite from the start, then drop any leftover bytes in case the
        # new JSON text is shorter than the old one.
        config_file.seek(0)
        json.dump(config, config_file, indent=4)
        config_file.truncate()
|