CPM-9G-8B/9G-Train/cpm/utils/exporter.py

131 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import functools
import json
import os
import shutil
import time
from typing import List
import bmtrain as bmt
import torch
from .log import logger
def rename_if_exists(file_path):
    """Move an already-existing file out of the way before it gets overwritten.

    If ``file_path`` does not exist, nothing happens. Otherwise the file is
    renamed, inside the same directory, to
    ``<root>_bak_<YYYYmmddHHMMSS><ext>``, where the timestamp comes from
    ``time.strftime``.

    Any exception raised while renaming (or while logging the success
    message) is caught and reported through ``logger.warn``; it is never
    propagated to the caller.

    Args:
        file_path: Path of the file that is about to be written.
    """
    if not os.path.exists(file_path):
        return

    # Same two components ``os.path.split`` would return.
    file_name = os.path.basename(file_path)
    parent_dir = os.path.dirname(file_path)

    file_root, file_ext = os.path.splitext(file_name)
    timestamp = time.strftime("%Y%m%d%H%M%S")
    new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
    backup_path = os.path.join(parent_dir, new_file_name)

    try:
        os.rename(file_path, backup_path)
        logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
    except Exception as e:
        # Best effort: a failed backup is only logged, the caller still proceeds.
        logger.warn(
            "rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
                file_path=file_path, new_file_path=backup_path, err=str(e)
            )
        )
def rename_if_exists_decorator(func):
    """Wrap a saver so an existing target file is backed up before it runs.

    The returned callable takes the same arguments as ``func``; its first
    parameter must be the destination path (named ``file_path``). It first
    calls ``rename_if_exists(file_path)`` and then forwards every argument to
    ``func``, returning whatever ``func`` returns.

    Args:
        func: Callable whose first parameter is the file path it writes to.

    Returns:
        The wrapping callable, carrying ``func``'s metadata.
    """

    def backup_then_call(file_path, *args, **kwargs):
        rename_if_exists(file_path)
        outcome = func(file_path, *args, **kwargs)
        return outcome

    # Equivalent to decorating the inner function with ``functools.wraps(func)``.
    return functools.update_wrapper(backup_then_call, func)
@rename_if_exists_decorator
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
    """Save ``model`` to ``file_path`` through ``bmt.save``.

    Because of the decorator, a file already present at ``file_path`` is
    renamed to a timestamped backup before the save runs.

    Args:
        file_path: Destination of the model checkpoint.
        model: Module handed to ``bmt.save``.
        export_files: Optional list; when given, ``file_path`` is appended to
            it after the save call.
    """
    bmt.save(model, file_path)
    if export_files is None:
        return
    export_files.append(file_path)
@rename_if_exists_decorator
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
    """Serialize ``obj`` to ``file_path`` with ``torch.save``.

    Because of the decorator, a file already present at ``file_path`` is
    renamed to a timestamped backup before the save runs.

    Args:
        file_path: Destination file.
        obj: Object passed to ``torch.save``.
        export_files: Optional list; when given, ``file_path`` is appended to
            it after the save call.
    """
    torch.save(obj, file_path)
    if export_files is None:
        return
    export_files.append(file_path)
@rename_if_exists_decorator
def json_save(file_path: str, obj: object, export_files: List[str] = None):
    """Dump ``obj`` as JSON into ``file_path``.

    Because of the decorator, a file already present at ``file_path`` is
    renamed to a timestamped backup before the new file is opened for writing.

    Args:
        file_path: Destination file, opened in text mode ``"w"``.
        obj: Object passed to ``json.dump``.
        export_files: Optional list; when given, ``file_path`` is appended to
            it once the file has been written and closed.
    """
    with open(file_path, "w") as handle:
        json.dump(obj, handle)
    if export_files is None:
        return
    export_files.append(file_path)
def export(
    model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
):
    """Write one checkpoint for ``global_step`` and return its directory.

    Files produced (``base`` is ``{args.save_name}-{global_step}``, or just
    ``args.save_name`` when ``global_step`` is not greater than -1)::

        {args.save}/
            {args.save_name}-{global_step % (args.save_iters * 5)}.rank-{rank}.opt
            l_{global_step}/
                {base}.pt
                {base}.data
                {base}.data.json
                config.json
                vocabs.txt

    The ``.opt`` file is written by every rank, and only when ``final_save``
    is false. The ``.data``, ``.data.json``, ``config.json`` and
    ``vocabs.txt`` files are written only where ``bmt.rank() == 0``.

    Args:
        model: Model passed to ``bmt_save``.
        dataloader: Object whose ``state_dict()`` is a mapping of mappings
            whose leaf values provide ``.tolist()``.
        optimizer: Optimizer whose ``state_dict()`` is stored with
            ``torch.save``.
        global_step: Step number used in directory and file names.
        args: Namespace read for ``save``, ``save_name``, ``save_iters``,
            ``model_config`` and ``vocab``.
        final_save: When true, the optimizer state is not written.

    Returns:
        Path of the ``l_{global_step}`` directory under ``args.save``.
    """
    export_model_dir = os.path.join(args.save, f"l_{global_step}")
    os.makedirs(export_model_dir, exist_ok=True)

    if global_step > -1:
        base_file_name = f"{args.save_name}-{global_step}"
    else:
        base_file_name = args.save_name
    logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")

    def checkpoint_file(suffix):
        # Path of a checkpoint artefact sharing the common file prefix.
        return os.path.join(export_model_dir, base_file_name + suffix)

    export_files = []

    # Model checkpoint.
    bmt_save(
        file_path=checkpoint_file(".pt"),
        model=model,
        export_files=export_files,
    )

    # The optimizer state is only needed to continue pre-training, so it is
    # skipped for the final save.
    if not final_save:
        # The step is reduced modulo ``args.save_iters * 5``, so these file
        # names repeat across steps.
        opt_slot = global_step % (args.save_iters * 5)
        opt_file_name = args.save_name + ("-%d.rank-%d.opt" % (opt_slot, bmt.rank()))
        grad_path = os.path.join(args.save, opt_file_name)
        torch.save(optimizer.state_dict(), grad_path)
        logger.info(f"Successfully save grad file: {grad_path}")

    # Queried on every rank; only rank 0 persists it below.
    loader_state = dataloader.state_dict()
    # NOTE(review): everything down to the final log line is assumed to belong
    # to the rank-0 branch -- confirm against the upstream file's indentation.
    if bmt.rank() == 0:
        # Data checkpoint: raw dataloader state.
        torch_save(
            file_path=checkpoint_file(".data"),
            obj=loader_state,
            export_files=export_files,
        )

        # Data checkpoint as JSON: same state with every leaf turned into a
        # plain list via ``.tolist()``.
        json_ready_state = {
            name: {field: state[field].tolist() for field in state}
            for name, state in loader_state.items()
        }
        json_save(
            file_path=checkpoint_file(".data.json"),
            obj=json_ready_state,
            export_files=export_files,
        )

        # The config and vocabulary are stored next to the model file.
        config_target = os.path.join(export_model_dir, "config.json")
        vocab_target = os.path.join(export_model_dir, "vocabs.txt")
        export_files.extend([config_target, vocab_target])
        shutil.copy(args.model_config, config_target)
        shutil.copy(args.vocab, vocab_target)
        logger.info(f"Successfully save model files! {export_files}")

    del loader_state
    return export_model_dir