forked from jiuyuan/CPM-9G-8B
532 lines
20 KiB
Python
532 lines
20 KiB
Python
|
# COPIED from exporter, for scaling project special use.
|
|||
|
# Author: shengdinghu@gmail.com
|
|||
|
|
|||
|
import functools
|
|||
|
import gc
|
|||
|
import glob
|
|||
|
import json
|
|||
|
import multiprocessing as mp
|
|||
|
import os
|
|||
|
import re
|
|||
|
from copy import deepcopy
|
|||
|
from collections import defaultdict
|
|||
|
from itertools import chain
|
|||
|
import shutil
|
|||
|
import threading
|
|||
|
import time
|
|||
|
import hashlib
|
|||
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
from typing import List
|
|||
|
|
|||
|
import bmtrain as bmt
|
|||
|
from bmtrain.distributed import all_reduce, all_gather
|
|||
|
import torch
|
|||
|
|
|||
|
from fm9g.utils import allgather_objects
|
|||
|
from fm9g.utils.bitset import BitSet
|
|||
|
|
|||
|
from .log import logger
|
|||
|
|
|||
|
# Module-level lock. The dataset-state loaders in this file
# (_load_distributed_dataset_state, merge_dataset_ckpts) hold it while creating
# their per-key threading.Lock objects.
lock = threading.Lock()
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def _save_artifacts(
    model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt=None, final_save=False
):
    """
    Export model artifacts. Mainly for the purpose of asynchronous saving.

    Every rank writes its own optimizer state and, when ``tp_rank == 0``, its
    dataset state. Rank 0 additionally writes the model weights, the model
    config and the tokenizer files. Each rank then drops a
    ``save_status/<rank>.save_done`` marker. Rank 0 polls until all markers
    exist and then merges the per-rank dataset states via ``merge_dataset_ckpts``.

    Args:
        model_state: Model state dict (written only by rank 0).
        dataloader: Object exposing ``save_state_dict(path)``.
        tokenizer: Tokenizer. ``save_pretrained`` is tried first; on failure the
            file at ``args.tokenizer_path`` is copied instead.
        opt_state: This rank's optimizer state dict.
        global_step: Step used in directory/file names. When it is ``-1`` or
            lower, the file prefix is just ``args.save_name``.
        args: Namespace providing ``save``, ``save_name``, ``tokenizer_path``
            and ``parallel_load_datastate``.
        model_config: Config object exposing ``save_pretrained(dir)``.
        log_ckpt: Optional JSON-serializable statistics, written to
            ``<prefix>.log_ckpt``.
        final_save: Only consulted when a platform config is available.
    """
    try:
        # platform_cfg = get_platform_cfg()
        # No platform config is wired in, so the ValueError branch is always taken.
        raise ValueError("No platform_cfg")
    except ValueError:
        platform_cfg = None
        export_model_dir = os.path.join(args.save, str(global_step))
    else:
        export_model_dir = (
            args.save if final_save else platform_cfg.gen_export_ckpt_dir_for_step(args.save, global_step)
        )
    os.makedirs(export_model_dir, exist_ok=True)
    base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
    bmt.print_rank(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")

    # model checkpoint
    ckpt_path = os.path.join(export_model_dir, base_file_name + ".pt")
    # The .opt files are only used to resume training within this project and
    # do not need to be exported as part of a model version.
    opt_path = os.path.join(
        export_model_dir,
        args.save_name + ("-%d.rank-%d.opt" % (global_step, bmt.rank())),
    )

    if bmt.rank() == 0:
        torch.save(model_state, ckpt_path)
        bmt.print_rank(f"Save checkpoint successfully, ckpt file path: {ckpt_path}")
    # Every rank writes its own optimizer shard (the rank is part of the file name).
    torch.save(opt_state, opt_path)
    print(f"Save optimizer state successfully, opt file path: {opt_path}")

    # Drop the references early so the (possibly large) states can be freed.
    del model_state
    del opt_state

    # Save training statistics.
    if log_ckpt is not None:
        bmt.print_rank("save log ckpt ...")
        with open(os.path.join(export_model_dir, base_file_name + ".log_ckpt"), "w") as log_ckpt_file:
            json.dump(log_ckpt, log_ckpt_file)

    logger.info(f"Starting saving dataset state. ")

    dataset_ckpt_path = os.path.join(export_model_dir, "dataset_ckpt")
    os.makedirs(dataset_ckpt_path, exist_ok=True)
    if bmt.config["tp_rank"] == 0:
        p_dataset = os.path.join(dataset_ckpt_path, f"dataset_{bmt.rank()}.data")
        dataloader.save_state_dict(p_dataset)

    if bmt.rank() == 0:
        # The config and vocab files are stored alongside the model file.
        model_config.save_pretrained(export_model_dir)
        try:
            tokenizer.save_pretrained(export_model_dir)
        except Exception:
            # Was a bare ``except:``. Narrowed so KeyboardInterrupt/SystemExit
            # still propagate; any ordinary failure still triggers the copy fallback.
            bmt.print_rank("No save pretrained method for tokenizer")
            shutil.copy(args.tokenizer_path, export_model_dir)

        # Called once all files have been written.
        if platform_cfg is not None:
            platform_cfg.finalize_model_save(export_model_dir, base_file_name)
        else:
            bmt.print_rank("No platform_cfg, skip finalize_model_save, may be not have .success file")
        logger.info(f"Successfully save model files: {os.listdir(export_model_dir)}")

    # Each rank writes a .save_done marker into export_model_dir/save_status so
    # rank 0 can tell when every rank has finished. Plain files are used on purpose.
    status_dir = os.path.join(export_model_dir, "save_status")
    os.makedirs(status_dir, exist_ok=True)
    with open(os.path.join(status_dir, f"{bmt.rank()}.save_done"), "w") as f:
        f.write("done")

    # Wait until every rank has finished saving. bmt.synchronize() cannot be used
    # here. NOTE(review): likely because this function may run in a separate
    # process spawned by export(async_save=True); confirm.
    if bmt.rank() == 0:
        while True:
            # Count only the marker files and accept >= so that a stray entry in
            # save_status cannot make this loop spin forever.
            done_markers = [name for name in os.listdir(status_dir) if name.endswith(".save_done")]
            if len(done_markers) >= bmt.world_size():
                break
            time.sleep(1)
        bmt.print_rank(f"All saved! Rank 0 Begin to merge dataset ckpt to {dataset_ckpt_path}/dataset_.data")
        # Half the normal load parallelism to limit memory, but never 0 workers
        # (ThreadPoolExecutor rejects max_workers=0).
        merge_dataset_ckpts(export_model_dir, max(1, args.parallel_load_datastate // 2))
    else:
        bmt.print_rank(f"rank-{bmt.rank()} done, wait for rank0 to merge dataset ckpt")
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def export(
    model: torch.nn.Module,
    dataloader,
    tokenizer,
    optimizer: bmt.optim.AdamOffloadOptimizer,
    global_step,
    args,
    log_ckpt=None,
    final_save=False,
    async_save=False,
):
    """Snapshot model and optimizer state and write one checkpoint.

    All ranks synchronize first. The model weights are then gathered with
    ``bmt.store._save_to_rank0`` and the optimizer state is deep-copied, giving
    a snapshot independent of the live optimizer. Writing is delegated to
    ``_save_artifacts``. With ``async_save=True`` it runs in a
    ``multiprocessing.Process`` that is started but neither joined nor returned.

    Files produced by ``_save_artifacts`` under ``<args.save>/<global_step>/``::

        config / tokenizer files                  (rank 0)
        {save_name}-{global_step}.pt              model weights (rank 0)
        {save_name}-{global_step}.rank-N.opt      optimizer state, one per rank
        {save_name}-{global_step}.log_ckpt        statistics, if log_ckpt is given
        dataset_ckpt/                             dataset state, merged by rank 0
        save_status/N.save_done                   per-rank completion markers
    """
    bmt.synchronize()
    gathered_weights = bmt.store._save_to_rank0(model)
    optimizer_snapshot = deepcopy(optimizer.state_dict())
    # Positional arguments for _save_artifacts. model.config is read after the
    # optimizer snapshot, matching the original evaluation order.
    save_args = (
        gathered_weights,
        dataloader,
        tokenizer,
        optimizer_snapshot,
        global_step,
        args,
        model.config,
        log_ckpt,
        final_save,
    )
    if not async_save:
        _save_artifacts(*save_args)
        return
    # Save artifacts asynchronously; the process handle is not kept.
    mp.Process(target=_save_artifacts, args=save_args).start()
|
|||
|
|
|||
|
|
|||
|
def load_model_ckpt(args, model):
    """Load model weights from ``args.load`` into ``model``.

    ``args.load`` is either a ``.pt`` file or a checkpoint directory (e.g.
    ``/{args.save}/job_{job_id}_ckpt_{global_step}/``). A directory must
    contain exactly one ``.pt`` file.

    Returns:
        The same ``model`` object, after ``bmt.load`` has filled it.

    Raises:
        AssertionError: If the directory contains no ``.pt`` file or several.
    """
    if args.load.endswith(".pt"):
        checkpoint_file = args.load
    else:
        # a directory
        load_path = args.load
        checkpoint_files = [file for file in os.listdir(load_path) if file.endswith(".pt")]
        assert len(checkpoint_files) == 1, "None or multiple .pt found in {}".format(load_path)
        checkpoint_file = os.path.join(load_path, checkpoint_files[0])
    # The space before the path was missing ("...fromPATH"); the sibling loaders include it.
    bmt.print_rank("args.load is not None, start to load checkpoints from " + checkpoint_file)

    bmt.load(model, checkpoint_file)
    return model
|
|||
|
|
|||
|
|
|||
|
def _legacy_load_optimizer_ckpt(args, optimizer):
    """Load per-rank optimizer state using the legacy on-disk layout.

    In the legacy layout the ``.opt`` files live one directory above the
    checkpoint: ``dirname(dirname(args.load))`` for a ``.pt`` path and
    ``dirname(args.load)`` for a directory path. They are keyed by
    ``args.start_step % (args.save_iters * 5)``. State is loaded only when the
    number of matching ``.opt`` files equals the world size; otherwise
    ``optimizer`` is returned unchanged.
    """
    bmt.print_rank("Use legacy optimizer ckpt!")
    if args.load.endswith(".pt"):
        optimizer_path = os.path.dirname(os.path.dirname(args.load))
    else:
        optimizer_path = os.path.dirname(args.load)
    # List the directory once; the previous code listed it three times and
    # computed the same match count twice.
    file_names = os.listdir(optimizer_path)
    bmt.print_rank(file_names)
    start = time.time()
    legacy_step = args.start_step % (args.save_iters * 5)
    step_tag = "-{}.rank".format(legacy_step)
    opt_count = sum(1 for name in file_names if ".opt" in name and step_tag in name)
    bmt.print_rank("{}".format(opt_count))
    if opt_count == bmt.world_size():
        pattern = "-{}.rank-{}.opt".format(legacy_step, bmt.rank())
        bmt.print_rank("Will load opt that matches pattern: {}".format(pattern))
        for file_name in file_names:
            if pattern in file_name:
                bmt.print_rank("start to load grad ckpt {}".format(file_name))
                states = torch.load(os.path.join(optimizer_path, file_name))
                optimizer.load_state_dict(states)
    logger.info("load grad in {:.2f}s".format(time.time() - start))
    return optimizer
|
|||
|
|
|||
|
|
|||
|
def load_optimizer_ckpt(args, optimizer):
    """Load this rank's optimizer state for ``args.start_step``.

    Looks for ``{args.save_name}-{args.start_step}.rank-{rank}.opt`` next to the
    checkpoint given by ``args.load``, which is either a ``.pt`` file or a
    directory. If no per-rank ``.opt`` file for this step exists at all, the
    function falls back to ``_legacy_load_optimizer_ckpt``. State is loaded only
    when the number of per-rank files equals the world size and this rank's file
    exists; otherwise a message is printed and the optimizer is left unchanged.

    Returns:
        The (possibly updated) ``optimizer``.
    """
    if args.load.endswith(".pt"):
        optimizer_path = os.path.dirname(args.load)
    else:
        # a directory
        optimizer_path = args.load
    start = time.time()
    # Dots are escaped and the suffix is anchored. The previous pattern's bare "."
    # matched any character, and it also counted names such as "...rank-0.opt.tmp".
    opt_regex = re.compile(r"-{}\.rank-\d+\.opt$".format(re.escape(str(args.start_step))))
    opt_num = sum(1 for name in os.listdir(optimizer_path) if opt_regex.search(name))

    bmt.print_rank(f"Opt file num: {opt_num}")
    if opt_num == 0:
        return _legacy_load_optimizer_ckpt(args, optimizer)

    if opt_num == bmt.world_size():
        file_name = os.path.join(
            optimizer_path,
            args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
        )

        if os.path.exists(file_name):
            print("rank {} start to load grad ckpt {}".format(bmt.rank(), file_name))
            states = torch.load(file_name)
            optimizer.load_state_dict(states)
        else:
            print("rank {} optimizer ckpt {} not found, optimizer state NOT loaded".format(bmt.rank(), file_name))
    else:
        # Previously this case silently skipped loading and still reported a load time.
        bmt.print_rank(f"Opt file num {opt_num} != world size {bmt.world_size()}, optimizer state NOT loaded")
    logger.info("load grad in {:.2f}s".format(time.time() - start))
    return optimizer
|
|||
|
|
|||
|
|
|||
|
def load_dataloader_ckpt(args, mixed_dataset):
    """Restore dataset state, preferring the distributed (per-rank) format.

    ``_load_distributed_dataset_state`` is tried first. If it reports failure
    (a falsy return), the single-file loader ``_load_dataloader_ckpt`` is used.
    """
    if _load_distributed_dataset_state(args, mixed_dataset):
        return
    logger.info("load from distributed data state dict fail, try to load from single data state_dict")
    _load_dataloader_ckpt(args, mixed_dataset)
|
|||
|
|
|||
|
|
|||
|
def _load_dataloader_ckpt(args, mixed_dataset):
|
|||
|
"""args.load 是一个到/{args.save}/job_{job_id}_ckpt_{global_step}/ 的路径"""
|
|||
|
if args.load.endswith(".pt"):
|
|||
|
load_path = os.path.dirname(args.load)
|
|||
|
else:
|
|||
|
load_path = args.load
|
|||
|
dataset_states_path = [file for file in os.listdir(load_path) if file.endswith(".data")]
|
|||
|
assert len(dataset_states_path) == 1, "None or multiple .data found in {}, file list: {}".format(
|
|||
|
load_path, dataset_states_path
|
|||
|
)
|
|||
|
dataset_states_path = dataset_states_path[0]
|
|||
|
dataset_states_path = os.path.join(load_path, dataset_states_path)
|
|||
|
bmt.print_rank("args.load is not None, start to load data ckpt from " + dataset_states_path)
|
|||
|
dataset_states = torch.load(dataset_states_path)
|
|||
|
missing = mixed_dataset.load_state_dict(dataset_states)
|
|||
|
if len(missing) > 0:
|
|||
|
bmt.print_rank("missing keys when loading dataset states: ", missing)
|
|||
|
|
|||
|
|
|||
|
def load_trace_ckpt(args, dataloader):
    """Placeholder hook for restoring trace state; currently a no-op.

    ``args.load`` is meant to point at
    ``/{args.save}/job_{job_id}_ckpt_{global_step}/``, but nothing is read
    from it. ``args`` is ignored and ``dataloader`` is returned unchanged.
    """
    unchanged = dataloader
    return unchanged
|
|||
|
|
|||
|
|
|||
|
def load_log_ckpt(args):
    """Load the JSON statistics file (``*.log_ckpt``) stored with a checkpoint.

    ``args.load`` is either a ``.pt`` file, whose directory is searched, or a
    checkpoint directory.

    Returns:
        The decoded JSON object, or ``{}`` when no ``.log_ckpt`` file exists.

    Raises:
        AssertionError: If more than one ``.log_ckpt`` file is present.
    """
    if args.load.endswith(".pt"):
        load_path = os.path.dirname(args.load)
    else:
        load_path = args.load
    log_ckpt_paths = [file for file in os.listdir(load_path) if file.endswith(".log_ckpt")]
    # The message used to say ".data", which named the wrong file type.
    assert len(log_ckpt_paths) <= 1, "Multiple .log_ckpt found in {}: {}".format(load_path, log_ckpt_paths)
    if not log_ckpt_paths:
        bmt.print_rank("No log ckpt is found in {}".format(load_path))
        return {}
    log_ckpt_path = os.path.join(load_path, log_ckpt_paths[0])
    bmt.print_rank("args.load is not None, start to load log ckpt from " + log_ckpt_path)
    with open(log_ckpt_path, "r") as log_ckpt_file:
        log_ckpt = json.load(log_ckpt_file)
    return log_ckpt
|
|||
|
|
|||
|
|
|||
|
def _load_distributed_dataset_state(args, mixed_dataset):
    """Merge per-rank dataset states and load them into ``mixed_dataset``.

    Rank 0 reads every ``dataset_*.data`` file under ``<load_dir>/dataset_ckpt/``.
    For older checkpoints it falls back to ``<load_dir>/dataset_*.data``. The
    files are merged per key using ``args.parallel_load_datastate`` threads. The
    merged dict is broadcast from rank 0 to all ranks, and each rank calls
    ``mixed_dataset.load_state_dict`` on it.

    Args:
        args: Namespace providing ``load`` (a ``.pt`` file or a checkpoint
            directory) and ``parallel_load_datastate``.
        mixed_dataset: Object exposing ``load_state_dict(dict)``, which returns
            the missing keys.

    Returns:
        True if at least one dataset state file was found and loaded. False if
        none was found; the caller then falls back to the single-file loader.
    """
    rank = bmt.rank()
    logger.info(f"rank-{rank} -> [start]loading dataset states")
    if args.load.endswith(".pt"):
        load_dir = os.path.dirname(args.load)
    else:
        load_dir = args.load

    p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
    if len(p_datasets) == 0:  # backward compatibility: older checkpoints keep .data files at the top level
        bmt.print_rank("load_from_orginal_dataset_ckpt_folder")
        p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_*.data")))

    all_state_dict = dict()
    # Per-key locks shared by ALL worker threads. This dict used to be created
    # inside load_and_aggregate, i.e. once per file/thread, so the "per-key" locks
    # never excluded one another. Two threads could then initialise or update the
    # same key concurrently and lose "used" bits.
    dataset_locks = {}

    def key_lock(key):
        """Return the shared lock guarding ``all_state_dict[key]``, creating it once."""
        with lock:
            if key not in dataset_locks:
                dataset_locks[key] = threading.Lock()
            return dataset_locks[key]

    def new_key_init(key, state_dict):
        """Copy ``state_dict[key]`` into ``all_state_dict``, merging ``used`` into a new BitSet(1024)."""
        all_state_dict[key] = {}
        for second_key in state_dict[key]:
            if second_key == "used":
                all_state_dict[key]["used"] = BitSet(1024)
                all_state_dict[key]["used"].update(state_dict[key]["used"])
            else:
                all_state_dict[key][second_key] = state_dict[key][second_key]

    def load_state_dict_with_npy(p_dataset):
        """torch.load one .data file and resolve ``used`` entries stored as BitSet file paths."""
        state_dict = torch.load(p_dataset)
        for key in state_dict:
            used = state_dict[key]["used"]
            if isinstance(used, str):
                if os.path.basename(used) == used:  # bare file name: it lives in <load_dir>/dataset_ckpt
                    state_dict[key]["used"] = BitSet.load(os.path.join(load_dir, "dataset_ckpt", used))
                else:  # full path (backward compatibility)
                    state_dict[key]["used"] = BitSet.load(used)
        return state_dict

    def load_and_aggregate(p_dataset):
        """Map func for loading and aggregating dataset states to all_state_dict"""
        print(f"Loading {p_dataset}...")
        # Disk I/O happens outside any lock; only the merge below is serialised per key.
        state_dict = load_state_dict_with_npy(p_dataset)
        for key in state_dict.keys():
            with key_lock(key):
                if key not in all_state_dict:
                    new_key_init(key, state_dict)
                elif all_state_dict[key]["exhausted"]:
                    continue
                elif state_dict[key]["exhausted"]:
                    # Once any per-rank file marks the key exhausted, its usage bits are reset.
                    all_state_dict[key]["exhausted"] = True
                    all_state_dict[key]["used"] = BitSet(1024)
                else:
                    all_state_dict[key]["used"].update(state_dict[key]["used"])
        bmt.print_rank(f"[done]loaded dataset states: {p_dataset}")

    if not p_datasets:
        logger.info(f"No dataset*.data found. p_datasets: {p_datasets}")
        return False

    lst_time = time.time()
    if rank == 0:
        with ThreadPoolExecutor(max_workers=args.parallel_load_datastate) as executor:
            # Consume the results so that an exception raised in a worker (e.g. a
            # corrupt .data file) propagates instead of silently yielding a
            # partially merged state.
            list(executor.map(load_and_aggregate, p_datasets))

    logger.info(
        f"rank-{rank} -> load dataset from {len(p_datasets)} .data files. Time: {time.time() - lst_time:.2f}s"
    )
    # Broadcast the merged state from rank 0 to every other rank.
    lst_time = time.time()
    all_state_dict = bmt.store.broadcast_object(all_state_dict, comm=bmt.config["comm"], src=0)
    logger.info(f"rank-{rank} -> broadcast dataset from rank-0 to other ranks. Time: {time.time() - lst_time:.2f}s")
    lst_time = time.time()

    missing_keys = mixed_dataset.load_state_dict(all_state_dict)
    logger.info(f"rank-{rank} -> load mixed dataset state dict. Time: {time.time() - lst_time:.2f}s")
    if missing_keys:
        logger.info(
            f"rank-{rank} -> load dataset from {len(p_datasets)} .data files with {os.path.join(load_dir, 'dataset_ckpt', 'dataset*.pt')} : {p_datasets}, missing tasks: {missing_keys}"
        )
    else:
        state_info = {
            k: {
                "ave_tokens": s.get("ave_tokens", -1),
                "set_info": "mem:{}|density:{:.4f}|len:{}".format(
                    s["used"].memory_usage(), s["used"].density(), len(s["used"])
                ),
            }
            for k, s in all_state_dict.items()
        }
        logger.info(
            f"rank-{rank} -> load dataset from {len(p_datasets)} files with {os.path.join(load_dir, 'dataset*.pt')}. Info: {state_info}"
        )
    return True
|
|||
|
|
|||
|
|
|||
|
import json
|
|||
|
import os
|
|||
|
from collections import OrderedDict
|
|||
|
|
|||
|
|
|||
|
def flatten_stats(stats, parent_key="", separator="/"):
    """Flatten a nested stats dict into a single-level ``OrderedDict``.

    Nested dict keys are joined with ``separator``, e.g. ``{"a": {"b": 1}}``
    becomes ``{"a/b": 1}``. List values are JSON-encoded to strings; all other
    values are kept as-is. Output order follows input order.

    Args:
        stats: Mapping to flatten.
        parent_key: Prefix applied to every key (used by the recursion).
        separator: String placed between nested key components.

    Returns:
        OrderedDict mapping flattened keys to values.
    """
    items = []
    for key, value in stats.items():
        new_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, dict):
            items.extend(flatten_stats(value, new_key, separator).items())
        # ``elif`` (was ``if``): without it, the final ``else`` also appended
        # every nested dict unflattened under ``new_key``.
        elif isinstance(value, list):
            items.append((new_key, json.dumps(value)))
        else:
            items.append((new_key, value))
    return OrderedDict(items)
|
|||
|
|
|||
|
|
|||
|
def save_every_step_stats(stats, path):
    """Append one flattened stats record as a JSON line under ``<path>/train_stats/``.

    Records go to ``train_stats/<k>.jsonl``, where ``k`` is the first index whose
    file is either missing or at most 19 MiB. Files larger than that are treated
    as full and the next index is tried.
    """
    flat_record = flatten_stats(stats)

    os.makedirs(os.path.join(path, "train_stats/"), exist_ok=True)

    rollover_bytes = 19 * 1024 * 1024  # files strictly larger than this are considered full

    # Walk 0.jsonl, 1.jsonl, ... until we reach a file that does not exist yet
    # or still has room.
    file_index = 0
    target_file = os.path.join(path, f"train_stats/{file_index}.jsonl")
    while os.path.exists(target_file) and os.path.getsize(target_file) > rollover_bytes:
        file_index += 1
        target_file = os.path.join(path, f"train_stats/{file_index}.jsonl")

    with open(target_file, "a") as out_file:
        out_file.write(json.dumps(flat_record) + "\n")
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
def merge_dataset_ckpts(load_dir, parallel_load_datastate):
    """Merge the per-rank dataset states in ``<load_dir>/dataset_ckpt`` into one file.

    Every ``dataset_*.data`` file except a previous merge result
    (``dataset_.data``) is loaded and merged per key. ``used`` BitSets are
    combined with ``BitSet.update``. A key marked ``exhausted`` in any file ends
    up exhausted with a fresh ``BitSet(1024)``. Each merged ``used`` set is
    written through ``BitSet.save(save_path)`` and replaced by the returned
    path, and the dict is saved to ``dataset_ckpt/dataset_.data``. Finally the
    per-rank ``dataset_*.data*`` files are deleted.

    Args:
        load_dir: Checkpoint directory containing ``dataset_ckpt/``.
        parallel_load_datastate: Thread-pool size for loading. Values below 1
            are clamped to 1; None keeps the executor default.
    """
    p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
    # Skip the output of a previous merge.
    p_datasets = [x for x in p_datasets if "dataset_ckpt/dataset_.data" not in x]

    bmt.print_rank(f"Files before merge (total num {len(p_datasets)}): {p_datasets}")

    all_state_dict = dict()
    # Per-key locks shared by ALL worker threads. This dict was previously created
    # inside load_and_aggregate (once per file/thread), so the locks never
    # excluded one another and concurrent merges of one key could lose "used" bits.
    dataset_locks = {}

    def key_lock(key):
        """Return the shared lock guarding ``all_state_dict[key]``, creating it once."""
        with lock:
            if key not in dataset_locks:
                dataset_locks[key] = threading.Lock()
            return dataset_locks[key]

    def new_key_init(key, state_dict):
        """Copy ``state_dict[key]`` into ``all_state_dict``, merging ``used`` into a new BitSet(1024)."""
        all_state_dict[key] = {}
        for second_key in state_dict[key]:
            if second_key == "used":
                all_state_dict[key]["used"] = BitSet(1024)
                all_state_dict[key]["used"].update(state_dict[key]["used"])
            else:
                all_state_dict[key][second_key] = state_dict[key][second_key]

    def load_state_dict_with_npy(p_dataset):
        """torch.load one .data file and load ``used`` BitSets stored as file names in dataset_ckpt."""
        state_dict = torch.load(p_dataset)
        for key in state_dict:
            if isinstance(state_dict[key]["used"], str):
                state_dict[key]["used"] = BitSet.load(
                    os.path.join(load_dir, "dataset_ckpt", os.path.basename(state_dict[key]["used"]))
                )
        return state_dict

    def load_and_aggregate(p_dataset):
        """Map func for loading and aggregating dataset states to all_state_dict"""
        print(f"Loading {p_dataset}...")
        # Disk I/O happens outside any lock; only the merge below is serialised per key.
        state_dict = load_state_dict_with_npy(p_dataset)
        for key in state_dict.keys():
            with key_lock(key):
                if key not in all_state_dict:
                    new_key_init(key, state_dict)
                elif all_state_dict[key]["exhausted"]:
                    continue
                elif state_dict[key]["exhausted"]:
                    all_state_dict[key]["exhausted"] = True
                    all_state_dict[key]["used"] = BitSet(1024)
                else:
                    all_state_dict[key]["used"].update(state_dict[key]["used"])
        del state_dict
        print(f"Merged {p_dataset}...")

    if not p_datasets:
        return

    # ThreadPoolExecutor rejects max_workers=0, which callers passing "n // 2" can produce.
    workers = None if parallel_load_datastate is None else max(1, parallel_load_datastate)
    with ThreadPoolExecutor(max_workers=workers) as executor:  # smaller than normal load to avoid OOM
        # Consume the results: an exception in any worker must abort here, BEFORE
        # the per-rank source files are deleted below. Previously such errors
        # were swallowed and the inputs removed anyway.
        list(executor.map(load_and_aggregate, p_datasets))

    # Persist the merged state (nothing is broadcast here).
    save_path = os.path.join(load_dir, "dataset_ckpt", "dataset_.data")
    for key in all_state_dict:
        # NOTE(review): every key passes the same save_path; presumably BitSet.save
        # derives a distinct file per set and returns its path. Confirm in fm9g.utils.bitset.
        npy_path = all_state_dict[key]["used"].save(save_path)
        all_state_dict[key]["used"] = npy_path
    bmt.print_rank(f"All state_dict after merge {all_state_dict}")
    torch.save(all_state_dict, save_path)

    # Remove the per-rank inputs (and their side files), keeping the merged output.
    files_to_remove = glob.glob(os.path.join(load_dir, "dataset_ckpt", "dataset_*.data*"))
    for file in files_to_remove:
        if "dataset_.data" not in file:
            os.remove(file)

    files_after_merge = os.listdir(os.path.join(load_dir, "dataset_ckpt"))
    bmt.print_rank(f"Files after merge: {files_after_merge}")
|