# forked from p83651209/CPM-9G-8B
# (repository viewer metadata: 74 lines, 2.4 KiB, Python)
import torch
|
|
import os
|
|
import sys
|
|
import shutil
|
|
import argparse
|
|
# Configure PyTorch's CUDA caching allocator through its environment variable:
# "max_split_size_mb:128" is the documented knob that stops the allocator from
# splitting cached blocks larger than 128 MB, which is meant to limit memory
# fragmentation.
# NOTE(review): this runs after `import torch`; presumably it still takes
# effect because no CUDA allocation has happened yet at this point -- confirm
# for the PyTorch version in use.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
|
|
|
def get_args():
    """Parse the command-line options of this script.

    Three string options are declared, all of them required:
    ``--base_path``, ``--delta_path`` and ``--merge_path``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, exposing
        ``base_path``, ``delta_path`` and ``merge_path``.
    """
    cli = argparse.ArgumentParser()

    # (flag, help text) for every required path option, in declaration order.
    required_paths = (
        ("--base_path", "base model path"),
        ("--delta_path", "the lora model path"),
        ("--merge_path", "merge the base and lora model as one models"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, help=description, required=True)

    return cli.parse_args()
|
|
|
|
|
|
def _collect_lora_deltas(delta_state, scale, device):
    """Pair each layer's ``lora_A``/``lora_B`` tensors into a weight delta.

    Args:
        delta_state: Mapping of checkpoint key to tensor. Keys whose last
            dotted component is ``lora_A`` or ``lora_B`` are paired by the
            part of the key that precedes it; other keys are ignored.
        scale: Numerator of the per-layer LoRA weight ``scale / rank``.
        device: ``torch.device`` on which the matrix product is computed.

    Returns:
        Dict mapping the layer prefix to ``(lora_B @ lora_A, scale / rank)``,
        where the product lives on the CPU and ``rank`` is ``lora_B.shape[1]``.
    """
    halves = {}
    for key, value in delta_state.items():
        print(key)
        print(value.shape)
        layer, _, kind = key.rpartition(".")
        if kind in ("lora_A", "lora_B"):
            halves.setdefault(layer, {})[kind] = value

    deltas = {}
    for layer, pair in halves.items():
        if len(pair) != 2:
            # Half of a pair cannot form a delta; report it instead of
            # adding a wrongly shaped tensor to the base weight later.
            print(f"warning: {layer} only has {next(iter(pair))}; skipping")
            continue
        lora_a = pair["lora_A"].to(device)
        lora_b = pair["lora_B"].to(device)
        # The weight is kept per layer so layers with different ranks are
        # each scaled correctly, whatever order the keys were stored in.
        rank = lora_b.shape[1]
        deltas[layer] = (torch.mm(lora_b, lora_a).detach().cpu(), scale / rank)
    return deltas


def merge_lora_models(args, scale=64):
    """Fold LoRA deltas into a base checkpoint and save the merged result.

    For every base entry ``<prefix>.<name>`` for which the delta checkpoint
    holds both ``<prefix>.lora.lora_A`` and ``<prefix>.lora.lora_B``, the
    entry is replaced by ``value + (scale / rank) * (lora_B @ lora_A)``.
    All other base entries are written out untouched.

    The heavy tensor math runs on CUDA when it is available and on the CPU
    otherwise. The merged tensor keeps the dtype of the base tensor.

    Args:
        args: Object exposing ``delta_path`` (LoRA checkpoint), ``base_path``
            (base checkpoint) and ``merge_path`` (output file).
        scale: Numerator of the LoRA weight ``scale / rank``. Defaults to 64.

    Returns:
        None. The function returns normally and does not terminate the
        interpreter, so it can be called from other code.

    Raises:
        ValueError: If a computed delta does not have the same shape as the
            base tensor it would be added to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # Loading onto the CPU keeps GPU memory free for the per-layer products.
    delta_state = torch.load(args.delta_path, map_location=cpu)
    if device.type == "cuda":
        allocated_mem = torch.cuda.memory_allocated()
        print(f"allocated GPU memory: {allocated_mem/1024**3} GB")

    deltas = _collect_lora_deltas(delta_state, scale, device)
    del delta_state  # raw halves are no longer needed once paired
    print("end")

    if device.type == "cuda":
        torch.cuda.empty_cache()
    print("begin")
    base_model = torch.load(args.base_path, map_location=cpu)

    # Iterate over a snapshot of the keys because entries are reassigned.
    for key in list(base_model):
        layer = key.rpartition(".")[0] + ".lora"
        if layer not in deltas:
            # Nothing to merge: leave the tensor alone (no device round trip).
            continue

        delta, weight = deltas[layer]
        value = base_model[key]
        if delta.shape != value.shape:
            # Without this check torch.add could silently broadcast.
            raise ValueError(
                f"LoRA delta for {key} has shape {tuple(delta.shape)}, "
                f"expected {tuple(value.shape)}"
            )

        print(layer)
        print(f"alpha: {scale} | weight: {weight}")
        # Cast to the base dtype so the merged tensor is not promoted.
        update = weight * delta.to(device=device, dtype=value.dtype)
        merged = torch.add(value.to(device), update).detach().cpu()
        print(merged)
        base_model[key] = merged

    torch.save(base_model, args.merge_path)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the three checkpoint paths from the command
    # line and hand them straight to the merge routine.
    merge_lora_models(get_args())