mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Exchange recv displacement arrays between all ranks via bootstrap allGather
This commit is contained in:
@@ -317,13 +317,14 @@ class MscclppAlltoAllV:
|
||||
|
||||
def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
|
||||
"""
|
||||
Exchange recv displacement arrays between all ranks via bootstrap send/recv.
|
||||
Exchange recv displacement arrays between all ranks via bootstrap allGather.
|
||||
|
||||
Each rank needs to know where to write in each peer's output buffer.
|
||||
remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in
|
||||
peer's output buffer where data from this rank should be placed.
|
||||
|
||||
Uses a deadlock-free pattern: lower-ranked peer sends first, then receives.
|
||||
Uses bootstrap.all_gather() (ring sockets, pre-established during
|
||||
initialize()) instead of pairwise TCP send/recv to avoid deadlocks.
|
||||
|
||||
Args:
|
||||
recv_displs_bytes: This rank's recv displacement array (in bytes)
|
||||
@@ -332,38 +333,19 @@ class MscclppAlltoAllV:
|
||||
List of remote recv displacements (one per rank, in bytes).
|
||||
remoteRecvDispls[rank] == recv_displs_bytes[rank] (this rank's own entry; unused by the kernel)
|
||||
"""
|
||||
import numpy as np
|
||||
rank = self._rank
|
||||
world_size = self._world_size
|
||||
bootstrap = self._comm.bootstrap()
|
||||
|
||||
# Use CPU int64 tensors as send/recv buffers (data_ptr() gives uintptr_t)
|
||||
my_displs = torch.tensor(recv_displs_bytes, dtype=torch.int64) # CPU tensor
|
||||
data_size = world_size * 8 # world_size int64 values = world_size * 8 bytes
|
||||
|
||||
# Collect all ranks' recv_displs via pairwise send/recv
|
||||
all_recv_displs = [None] * world_size
|
||||
all_recv_displs[rank] = list(recv_displs_bytes)
|
||||
|
||||
for peer in range(world_size):
|
||||
if peer == rank:
|
||||
continue
|
||||
peer_buf = torch.zeros(world_size, dtype=torch.int64) # CPU tensor
|
||||
if rank < peer:
|
||||
# Lower rank sends first to avoid deadlock
|
||||
bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
|
||||
bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
|
||||
else:
|
||||
# Higher rank receives first
|
||||
bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
|
||||
bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
|
||||
all_recv_displs[peer] = peer_buf.tolist()
|
||||
|
||||
# Build remoteRecvDispls: for each peer, what offset in peer's output
|
||||
# buffer should this rank's data go to?
|
||||
remote_recv_displs = []
|
||||
for peer in range(world_size):
|
||||
remote_recv_displs.append(int(all_recv_displs[peer][rank]))
|
||||
# All-gather: each rank contributes world_size int64 values
|
||||
all_data = np.zeros((world_size, world_size), dtype=np.int64)
|
||||
all_data[rank, :] = recv_displs_bytes
|
||||
per_rank_bytes = world_size * 8 # world_size x sizeof(int64)
|
||||
bootstrap.all_gather(all_data.ctypes.data, per_rank_bytes)
|
||||
|
||||
# remoteRecvDispls[peer] = peer's recv_displs[rank]
|
||||
remote_recv_displs = [int(all_data[peer, rank]) for peer in range(world_size)]
|
||||
return remote_recv_displs
|
||||
|
||||
def __del__(self):
|
||||
|
||||
Reference in New Issue
Block a user