def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
    """
    Exchange recv displacement arrays between all ranks via bootstrap allGather.

    Each rank needs to know where to write in each peer's output buffer.
    remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in peer's
    output buffer where data from this rank should be placed.

    Uses bootstrap.all_gather() (ring sockets, pre-established during
    initialize()) instead of pairwise TCP send/recv to avoid deadlocks.

    Args:
        recv_displs_bytes: This rank's recv displacement array (in bytes),
            with exactly one entry per rank.

    Returns:
        List of remote recv displacements (one per rank, in bytes).
        remoteRecvDispls[rank] == recv_displs_bytes[rank] (self, unused by kernel)

    Raises:
        ValueError: If recv_displs_bytes does not contain exactly
            world_size entries.
    """
    import numpy as np

    rank = self._rank
    world_size = self._world_size

    # Check the length up front. NumPy would otherwise broadcast a
    # single-element input across the whole row without complaint, and the
    # bogus offsets would then be published to every peer by the all-gather.
    num_displs = len(recv_displs_bytes)
    if num_displs != world_size:
        raise ValueError(
            f"recv_displs_bytes must have one entry per rank: "
            f"expected {world_size}, got {num_displs}"
        )

    bootstrap = self._comm.bootstrap()

    # Gather buffer: row r is rank r's displacement array. Only this rank's
    # row is filled locally; all_gather() is handed the raw address of the
    # (C-contiguous) matrix plus the byte size of one rank's row.
    all_data = np.zeros((world_size, world_size), dtype=np.int64)
    all_data[rank, :] = recv_displs_bytes
    per_rank_bytes = world_size * all_data.itemsize
    bootstrap.all_gather(all_data.ctypes.data, per_rank_bytes)

    # remoteRecvDispls[peer] = peer's recv_displs[rank], i.e. column `rank`
    # of the gathered matrix; tolist() yields plain Python ints.
    return all_data[:, rank].tolist()