forked from jiuyuan/CPM-9G-8B
30 lines
995 B
Python
30 lines
995 B
Python
|
import pickle
|
||
|
|
||
|
import bmtrain as bmt
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def allgather_objects(obj):
    """Exchange one picklable Python object per process and return all of them.

    The object is pickled, its byte length is all-gathered so that every
    process can size a common zero-padded uint8 buffer, the padded buffers are
    all-gathered, and each row is trimmed back to its true length and
    unpickled.

    Args:
        obj: Object to share. It must be serializable with ``pickle``.

    Returns:
        A list with one unpickled object per row of the gathered buffer. When
        ``bmt.world_size()`` is 1 the list is ``[obj]`` (the very same object,
        no pickle round-trip).
    """
    # Single-process run: nothing to exchange.
    if bmt.world_size() == 1:
        return [obj]

    with torch.no_grad():
        payload: bytes = pickle.dumps(obj)
        payload_size: int = len(payload)

        # Step 1: share the payload sizes, since the byte exchange below needs
        # equally sized buffers on every participant.
        local_size = torch.tensor([payload_size], device="cuda", dtype=torch.long)
        sizes = bmt.distributed.all_gather(local_size).view(-1).cpu()
        longest = sizes.max().item()

        # Step 2: copy the pickled bytes into a zero-padded CUDA buffer.
        # NOTE(review): ByteStorage.from_buffer / ByteTensor are legacy
        # typed-storage APIs; torch.frombuffer looks like the modern
        # equivalent -- confirm the minimum supported torch version first.
        send_buf = torch.zeros(longest, dtype=torch.uint8, device="cuda")
        send_buf[:payload_size] = torch.ByteTensor(torch.ByteStorage.from_buffer(payload))

        # Step 3: gather the padded buffers; the result is indexed as
        # (row, byte), presumably one row per rank -- verify against bmtrain.
        received = bmt.distributed.all_gather(send_buf).cpu()

        # Step 4: drop each row's padding and rebuild the original objects.
        return [
            pickle.loads(received[row, : sizes[row].item()].numpy().tobytes())
            for row in range(received.size(0))
        ]
|