131 lines
4.6 KiB
Python
131 lines
4.6 KiB
Python
import functools
|
||
import json
|
||
import os
|
||
import shutil
|
||
import time
|
||
from typing import List
|
||
|
||
import bmtrain as bmt
|
||
import torch
|
||
|
||
from .log import logger
|
||
|
||
|
||
def rename_if_exists(file_path):
    """Move an existing file at ``file_path`` aside as a timestamped backup.

    The backup is placed in the same directory and named
    ``<root>_bak_<YYYYmmddHHMMSS><ext>``. If that name is already taken (for
    example two backups of the same file within one second), a numeric
    ``_<n>`` suffix is appended so an earlier backup is never overwritten.
    Does nothing when ``file_path`` does not exist. A failed rename is logged
    as a warning and not raised, so the save that follows can still proceed.

    Args:
        file_path: path of the file that is about to be (re)written.
    """
    if not os.path.exists(file_path):
        return
    timestamp = time.strftime("%Y%m%d%H%M%S")
    file_dir, file_name = os.path.split(file_path)
    file_root, file_ext = os.path.splitext(file_name)
    new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
    new_file_path = os.path.join(file_dir, new_file_name)
    # The timestamp has one-second resolution and os.rename silently replaces an
    # existing destination on POSIX, so choose a free name instead of clobbering
    # a previous backup.
    counter = 1
    while os.path.exists(new_file_path):
        new_file_name = f"{file_root}_bak_{timestamp}_{counter}{file_ext}"
        new_file_path = os.path.join(file_dir, new_file_name)
        counter += 1
    try:
        os.rename(file_path, new_file_path)
    except OSError as e:
        # Best effort: a failed backup must not abort the save that follows.
        logger.warning(
            f"rename file failed,file_path={file_path}, new_file_path={new_file_path},err={e}"
        )
    else:
        logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
|
||
|
||
|
||
def rename_if_exists_decorator(func):
    """Decorate a ``save(file_path, ...)`` function with a backup step.

    Before ``func`` runs, any existing file at the target path is moved aside
    by ``rename_if_exists``. The decorated function must take the target path
    as its first argument (positionally or as ``file_path=``). Its return value
    is passed through unchanged.
    """

    @functools.wraps(func)
    def _backup_then_call(file_path, *args, **kwargs):
        rename_if_exists(file_path)
        result = func(file_path, *args, **kwargs)
        return result

    return _backup_then_call
|
||
|
||
|
||
@rename_if_exists_decorator
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
    """Save ``model`` to ``file_path`` with ``bmt.save``.

    Any file already at ``file_path`` is first moved aside by the decorator.
    When ``export_files`` is given, ``file_path`` is appended to it in place.
    """
    bmt.save(model, file_path)
    if export_files is None:
        return
    export_files.append(file_path)
|
||
|
||
|
||
@rename_if_exists_decorator
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
    """Serialise ``obj`` to ``file_path`` with ``torch.save``.

    Any file already at ``file_path`` is first moved aside by the decorator.
    When ``export_files`` is given, ``file_path`` is appended to it in place.
    """
    torch.save(obj, file_path)
    if export_files is None:
        return
    export_files.append(file_path)
|
||
|
||
|
||
@rename_if_exists_decorator
def json_save(file_path: str, obj: object, export_files: List[str] = None):
    """Write ``obj`` as JSON to ``file_path`` (text mode, ``json.dump`` defaults).

    Any file already at ``file_path`` is first moved aside by the decorator.
    When ``export_files`` is given, ``file_path`` is appended to it in place.
    """
    with open(file_path, "w") as out_file:
        json.dump(obj, out_file)
    if export_files is None:
        return
    export_files.append(file_path)
|
||
|
||
|
||
def _dataloader_state_to_json(all_states):
    """Return a JSON-serialisable copy of a dataloader ``state_dict``.

    Each top-level value is iterated as a mapping, and each of its entries is
    converted with ``.tolist()``. Presumably those entries are tensors or numpy
    arrays (TODO confirm against the dataloader implementation). The input is
    not modified.
    """
    return {
        key: {inner_key: inner[inner_key].tolist() for inner_key in inner}
        for key, inner in all_states.items()
    }


def export(
    model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
):
    """Write one training checkpoint and return the directory of exportable files.

    Layout produced under ``args.save``::

        {args.save}/
        ├── {save_name}-{global_step % (save_iters * 5)}.rank-{rank}.opt  # per rank; skipped if final_save
        └── l_{global_step}/      # returned; files here are meant to be exported together
            ├── config.json       # copy of args.model_config (rank 0 only)
            ├── vocabs.txt        # copy of args.vocab (rank 0 only)
            ├── {base}.pt         # model weights, written via bmt_save
            ├── {base}.data       # dataloader state_dict via torch_save (rank 0 only)
            └── {base}.data.json  # the same dataloader state as nested lists (rank 0 only)

    ``{base}`` is ``{save_name}-{global_step}``, or just ``save_name`` when
    ``global_step`` is -1 or lower. The save helpers move any existing
    ``.pt`` / ``.data`` / ``.data.json`` files aside before writing.
    ``config.json``, ``vocabs.txt`` and the ``.opt`` files are overwritten in
    place.

    NOTE(review): no ``.success`` marker is written here; if one is expected,
    it has to come from the caller.

    Args:
        model: model whose parameters are saved with ``bmt.save``.
        dataloader: object exposing ``state_dict()``; its state is written on rank 0.
        optimizer: optimizer whose ``state_dict()`` is written once per rank.
        global_step: current training step, used in file and directory names.
        args: namespace providing ``save``, ``save_name``, ``save_iters``,
            ``model_config`` and ``vocab``.
        final_save: when True, the optimizer state is not written.

    Returns:
        Path of the ``l_{global_step}`` export directory.
    """
    export_model_dir = os.path.join(args.save, f"l_{global_step}")
    os.makedirs(export_model_dir, exist_ok=True)
    base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
    logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
    export_files = []

    # Model checkpoint. NOTE(review): this runs on every rank; bmt.save is
    # presumably collective, so it must not move under the rank-0 branch below.
    bmt_save(
        file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
        model=model,
        export_files=export_files,
    )

    # Optimizer state is only needed to continue training, so the final save skips it.
    # The step index wraps modulo save_iters * 5, so these files are overwritten in rotation.
    if not final_save:
        grad_path = os.path.join(
            args.save,
            args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
        )
        torch.save(optimizer.state_dict(), grad_path)
        logger.info(f"Successfully save grad file: {grad_path}")

    # Called on every rank (it may involve cross-rank communication; TODO confirm).
    all_states = dataloader.state_dict()
    if bmt.rank() == 0:
        # Only rank 0 writes the dataloader state (binary and JSON) and the model
        # assets, avoiding concurrent writes to the same paths.
        torch_save(
            file_path=os.path.join(export_model_dir, base_file_name + ".data"),
            obj=all_states,
            export_files=export_files,
        )
        json_save(
            file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
            obj=_dataloader_state_to_json(all_states),
            export_files=export_files,
        )
        # Keep config and vocab next to the weights so the directory is self-contained.
        model_cfg_path = os.path.join(export_model_dir, "config.json")
        model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
        export_files.extend([model_cfg_path, model_vocab_path])
        shutil.copy(args.model_config, model_cfg_path)
        shutil.copy(args.vocab, model_vocab_path)
        logger.info(f"Successfully save model files! {export_files}")
    del all_states
    return export_model_dir
|