Exchange recv displacement arrays between all ranks via bootstrap allGather

This commit is contained in:
Qinghua Zhou
2026-03-05 15:19:20 +00:00
parent d5743e2d6c
commit b7b180df24

View File

@@ -317,13 +317,14 @@ class MscclppAlltoAllV:
def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
    """
    Exchange recv displacement arrays between all ranks via bootstrap allGather.

    Each rank needs to know where to write in each peer's output buffer:
    remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in
    peer's output buffer where data from this rank should be placed.

    Uses bootstrap.all_gather() (ring sockets, pre-established during
    initialize()) instead of pairwise TCP send/recv, which required a
    rank-ordered send-first/recv-first protocol to avoid deadlocks.

    Args:
        recv_displs_bytes: This rank's recv displacement array (in bytes),
            one int64 entry per rank.

    Returns:
        List of remote recv displacements (one per rank, in bytes).
        remoteRecvDispls[rank] == recv_displs_bytes[rank] (self entry,
        unused by the kernel).
    """
    import numpy as np

    rank = self._rank
    world_size = self._world_size
    bootstrap = self._comm.bootstrap()

    # All-gather buffer: row p holds rank p's recv_displs after the gather.
    # Each rank contributes world_size int64 values (its own row).
    all_data = np.zeros((world_size, world_size), dtype=np.int64)
    all_data[rank, :] = recv_displs_bytes
    per_rank_bytes = world_size * 8  # world_size x sizeof(int64)
    # all_gather takes a raw pointer (uintptr_t) plus the per-rank byte count;
    # numpy's .ctypes.data gives the buffer address without a copy.
    bootstrap.all_gather(all_data.ctypes.data, per_rank_bytes)

    # remoteRecvDispls[peer] = peer's recv_displs[rank]: column `rank`
    # across all rows. int() converts np.int64 to a plain Python int.
    remote_recv_displs = [int(all_data[peer, rank]) for peer in range(world_size)]
    return remote_recv_displs
def __del__(self):