CPM-9G-8B/9G-Train/cpm/utils/object.py

30 lines
995 B
Python
Raw Permalink Normal View History

2024-02-27 14:33:33 +08:00
import pickle
import bmtrain as bmt
import torch
def allgather_objects(obj) -> list:
    """Gather an arbitrary picklable object from every rank.

    Each rank pickles ``obj``. The byte lengths are all-gathered first so that
    every rank knows the largest payload. Each payload is then zero-padded to
    that length in a CUDA uint8 tensor and all-gathered. Finally, each rank's
    slice is trimmed back to its true length and unpickled.

    Args:
        obj: Any picklable object contributed by the calling rank.

    Returns:
        A list with one unpickled object per gathered row, in the row order
        produced by ``bmt.distributed.all_gather`` (presumably rank order --
        confirm against bmtrain). With a single process, ``[obj]`` is returned
        directly without pickling.

    Note:
        ``pickle.loads`` runs on bytes received from peer ranks. Only use this
        among trusted processes of the same job.
    """
    if bmt.world_size() == 1:
        return [obj]

    with torch.no_grad():
        payload = pickle.dumps(obj)
        payload_len = len(payload)

        # Exchange payload sizes first so every rank can allocate a buffer
        # large enough for the biggest one.
        local_len = torch.tensor([payload_len], device="cuda", dtype=torch.long)
        all_lens = bmt.distributed.all_gather(local_len).view(-1).cpu().tolist()
        max_len = max(all_lens)

        # all_gather requires equally sized tensors, so zero-pad this rank's
        # payload up to max_len.
        send_buf = torch.zeros(max_len, dtype=torch.uint8, device="cuda")
        # torch.frombuffer replaces the deprecated ByteStorage/ByteTensor path.
        # A bytearray gives it a writable buffer (bytes would trigger a
        # non-writable-buffer warning).
        send_buf[:payload_len] = torch.frombuffer(
            bytearray(payload), dtype=torch.uint8
        )
        gathered = bmt.distributed.all_gather(send_buf).cpu()

        # Trim each row back to its sender's true length before unpickling.
        return [
            pickle.loads(gathered[i, :n].numpy().tobytes())
            for i, n in enumerate(all_lens)
        ]