# Source: LLaMA-Factory-310P3/mindie/examples/models/chatglm/v2_6b/quant_utils.py (63 lines, 2.4 KiB, Python)

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import os
import json
import shutil
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from atb_llm.utils.file_utils import safe_open
def parse_args():
    """Parse the command-line options of the quantization script.

    Options:
        --model_path   (str, required) path to the float model weights.
        --save_path    (str, default "./quant_weight_glm") output directory.
        --dataset_path (str, required) path to the dataset.
        --group_size   (int, default 128) group size for w4a16.

    Returns:
        argparse.Namespace: the values parsed from ``sys.argv``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Creating quant weights for ChatGLM2-6B or ChatGLM3-6B"
    )
    # One (flag, keyword-arguments) pair per option, registered in order.
    option_specs = (
        ("--model_path",
         {"type": str, "required": True, "help": "The path to model float weights"}),
        ("--save_path",
         {"type": str, "default": "./quant_weight_glm", "help": "The path to save quant weights"}),
        ("--dataset_path",
         {"type": str, "required": True, "help": "The dataset path"}),
        ("--group_size",
         {"type": int, "default": 128, "help": "The group size for w4a16"}),
    )
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser.parse_args()
def get_model_and_tokenizer(model_path):
    """Load the tokenizer and the float32 model stored at ``model_path``.

    Both are loaded with ``trust_remote_code=True``.
    NOTE(review): that flag presumably lets transformers run the modeling code
    shipped with the checkpoint, so only point this at trusted paths - confirm.

    Args:
        model_path: location passed to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(model, tokenizer)``; the model is on CPU and in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_path,
        trust_remote_code=True,
    )
    float_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_path,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    float_model = float_model.cpu()
    float_model.eval()
    return float_model, tokenizer
def read_dataset(dataset_path):
    """Read a line-delimited JSON file into a list of parsed records.

    Each non-blank line of the file is decoded with ``json.loads`` and
    appended to the result in file order.

    Args:
        dataset_path: path of the dataset file, opened as UTF-8 text through
            ``safe_open`` (presumably in read mode by default - confirm
            against ``atb_llm.utils.file_utils``).

    Returns:
        list: one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    calib_set = []
    with safe_open(dataset_path, encoding='utf-8') as file:
        for line in file:
            record_text = line.strip()
            if not record_text:
                # A blank line (typically a trailing newline at the end of the
                # file) is not a JSON document; json.loads would raise on it.
                continue
            calib_set.append(json.loads(record_text))
    return calib_set
# Build the calibration dataset.
def get_calib_dataset(tokenizer, calib_list, device="cpu"):  # use device="npu:0" to quantize on an NPU
    """Tokenize every calibration record and move the results to ``device``.

    Args:
        tokenizer: callable invoked as ``tokenizer([text], return_tensors='pt')``;
            its result must expose a ``data`` mapping with the keys
            'input_ids', 'position_ids' and 'attention_mask'.
        calib_list: iterable of records, each providing the text under the
            'inputs_pretokenized' key.
        device: target passed to ``.to()`` for every returned value.

    Returns:
        list: one ``[input_ids, position_ids, attention_mask]`` list per record,
        in the order of ``calib_list``.
    """
    field_order = ('input_ids', 'position_ids', 'attention_mask')

    def _encode(sample):
        # The text is wrapped in a list so the tokenizer receives a batch of one.
        encoded = tokenizer([sample['inputs_pretokenized']], return_tensors='pt')
        return [encoded.data[field].to(device) for field in field_order]

    return [_encode(sample) for sample in calib_list]
def copy_config_files(fp16_path, quant_path, quant_type):
    """Copy config/tokenizer/modeling files and record the quantization type.

    Every entry of ``fp16_path`` whose name starts with "config", "tokeniz" or
    "modeling_chatglm.py" is copied (with metadata, via ``shutil.copy2``) into
    ``quant_path``. Afterwards ``quant_path/config.json`` is rewritten in place
    with its 'quantize' key set to ``quant_type``.

    Args:
        fp16_path: directory holding the original model files.
        quant_path: destination directory; the code does not create it.
        quant_type: value stored under the 'quantize' key of config.json.
    """
    wanted_prefixes = ("config", "tokeniz", "modeling_chatglm.py")
    for entry_name in os.listdir(fp16_path):
        if not entry_name.startswith(wanted_prefixes):
            continue
        source_file = os.path.join(fp16_path, entry_name)
        target_file = os.path.join(quant_path, entry_name)
        shutil.copy2(source_file, target_file)

    config_file = os.path.join(quant_path, "config.json")
    with safe_open(config_file, 'r+', encoding='utf-8') as handle:
        settings = json.load(handle)
        settings['quantize'] = quant_type
        # Rewind and overwrite from the start, then cut off any leftover bytes
        # in case the new content is shorter than the old one.
        handle.seek(0)
        json.dump(settings, handle, indent=4)
        handle.truncate()