# CPM-9G-8B/9G-Train/cpm/utils/exporter.py
# (last modified 2024-02-27 14:33:33 +08:00)
import functools
import json
import os
import shutil
import time
from typing import List, Optional

import bmtrain as bmt
import torch

from .log import logger
def rename_if_exists(file_path: str) -> None:
    """Move an existing file at *file_path* to a timestamped backup.

    The file is renamed to ``<root>_bak_<YYYYmmddHHMMSS><ext>`` in the same
    directory.  A failed rename is logged and swallowed on purpose: backing up
    is best-effort and must never block the save that follows.

    Args:
        file_path: path that is about to be (re)written.
    """
    if not os.path.exists(file_path):
        return
    timestamp = time.strftime("%Y%m%d%H%M%S")
    file_dir, file_name = os.path.split(file_path)
    file_root, file_ext = os.path.splitext(file_name)
    new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
    new_file_path = os.path.join(file_dir, new_file_name)
    try:
        os.rename(file_path, new_file_path)
        logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
    except OSError as e:
        # os.rename only raises OSError for filesystem failures; catching it
        # narrowly avoids hiding unrelated bugs.  logger.warning replaces the
        # deprecated logger.warn alias.
        logger.warning(
            "rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
                file_path=file_path, new_file_path=new_file_path, err=str(e)
            )
        )
def rename_if_exists_decorator(func):
    """Wrap a save function so its target file (first positional argument)
    is backed up via :func:`rename_if_exists` before *func* overwrites it."""

    @functools.wraps(func)
    def wrapped(file_path, *args, **kwargs):
        # Back up any existing file first, then delegate unchanged.
        rename_if_exists(file_path)
        return func(file_path, *args, **kwargs)

    return wrapped
@rename_if_exists_decorator
def bmt_save(file_path: str, model: torch.nn.Module, export_files: Optional[List[str]] = None):
    """Save *model* with bmtrain, backing up any pre-existing file first.

    Args:
        file_path: destination checkpoint path (``.pt``).
        model: model serialized via ``bmt.save``.
        export_files: optional accumulator; when given, the written path is
            appended so the caller can track every file produced by an export.
    """
    # Annotation fixed: the default is None, so the type is Optional[List[str]].
    bmt.save(model, file_path)
    if export_files is not None:
        export_files.append(file_path)
@rename_if_exists_decorator
def torch_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
    """Save *obj* with ``torch.save``, backing up any pre-existing file first.

    Args:
        file_path: destination path.
        obj: any object torch can serialize.
        export_files: optional accumulator; when given, the written path is
            appended so the caller can track every file produced by an export.
    """
    # Annotation fixed: the default is None, so the type is Optional[List[str]].
    torch.save(obj, file_path)
    if export_files is not None:
        export_files.append(file_path)
@rename_if_exists_decorator
def json_save(file_path: str, obj: object, export_files: Optional[List[str]] = None):
    """Dump *obj* as JSON to *file_path*, backing up any pre-existing file first.

    Args:
        file_path: destination ``.json`` path.
        obj: any JSON-serializable object.
        export_files: optional accumulator; when given, the written path is
            appended so the caller can track every file produced by an export.
    """
    # encoding is pinned to UTF-8: the platform default ("w" with no encoding)
    # can fail or mis-encode non-ASCII content in the dataloader state.
    with open(file_path, "w", encoding="utf-8") as data_f:
        json.dump(obj, data_f)
    if export_files is not None:
        export_files.append(file_path)
def export(
    model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
):
    """
    Save one checkpoint: model weights, per-rank optimizer state, dataloader
    state, and the config/vocab files.

    Layout under ``{args.save}``::

        {save_name}-{global_step}.rank-0.opt
        {save_name}-{global_step}.rank-n.opt
        l_{global_step}/                 # files in this directory are exported together
            config.json                  # as one model-group version
            vocabs.txt
            {args.save_name}-{global_step}.pt
            {args.save_name}-{global_step}.data
            {args.save_name}-{global_step}.data.json
            {args.save_name}-{global_step}.success

    NOTE(review): the original docstring named this directory
    ``job_{job_id}_ckpt_{global_step}/`` but the code creates ``l_{global_step}``
    — confirm which layout downstream export tooling expects.

    Returns the export directory path.
    """
    export_model_dir = os.path.join(args.save, f"l_{global_step}")
    os.makedirs(export_model_dir, exist_ok=True)
    # global_step == -1 means "no step suffix" on the exported file names.
    base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
    logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
    export_files = []
    # model checkpoint
    bmt_save(
        file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
        model=model,
        export_files=export_files,
    )
    # opt is only used for continual pre-training, not the final opt
    if not final_save:
        # Every rank saves its own optimizer shard.  The step is folded modulo
        # (save_iters * 5), so at most five generations of .opt files rotate
        # in place instead of growing without bound.
        grad_path = os.path.join(
            args.save,
            args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
        )
        torch.save(optimizer.state_dict(), grad_path)
        logger.info(f"Successfully save grad file: {grad_path}")
    all_states = dataloader.state_dict()
    if bmt.rank() == 0:
        # data checkpoint
        # rank 0 writes the dataloader state
        torch_save(
            file_path=os.path.join(export_model_dir, base_file_name + ".data"),
            obj=all_states,
            export_files=export_files,
        )
        # data checkpoint json
        # rank 0 writes the dataloader state into the json file
        data_p_json = {k: v for k, v in all_states.items()}
        for k in data_p_json:
            # assumes each state entry is a mapping of tensor-like values that
            # support .tolist() — TODO confirm against dataloader.state_dict()
            data_p_json[k] = {k_of_v: data_p_json[k][k_of_v].tolist() for k_of_v in data_p_json[k]}
        json_save(
            file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
            obj=data_p_json,
            export_files=export_files,
        )
        # config and vocabs are stored together with the model files
        model_cfg_path = os.path.join(export_model_dir, "config.json")
        model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
        export_files.extend([model_cfg_path, model_vocab_path])
        shutil.copy(args.model_config, model_cfg_path)
        shutil.copy(args.vocab, model_vocab_path)
        logger.info(f"Successfully save model files! {export_files}")
    # Release the (potentially large) dataloader state before returning.
    del all_states
    return export_model_dir