diff --git a/python/mscclpp/ext/alltoallv_single.py b/python/mscclpp/ext/alltoallv_single.py
index 29c01096..554eeab1 100644
--- a/python/mscclpp/ext/alltoallv_single.py
+++ b/python/mscclpp/ext/alltoallv_single.py
@@ -147,6 +147,7 @@ class MscclppAlltoAllV:
         self._d_send_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
         self._d_recv_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
         self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
+        self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
 
     @property
     def rank(self) -> int:
@@ -225,6 +226,12 @@ class MscclppAlltoAllV:
         self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
         self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
 
+        # Exchange recv displacements with all peers so each rank knows where to
+        # write in the remote output buffer. remoteRecvDispls[peer] = peer's
+        # recvDispls[rank], i.e. the offset in peer's output where our data goes.
+        remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
+        self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
+
         # Get stream
         if stream is None:
             stream = torch.cuda.current_stream()
@@ -236,6 +243,7 @@
             "sendDispls": self._d_send_displs.data_ptr(),
             "recvCounts": self._d_recv_counts.data_ptr(),
             "recvDispls": self._d_recv_displs.data_ptr(),
+            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
         }
 
         input_size = sum(send_counts_bytes)
@@ -262,6 +270,57 @@
 
         return output
 
+    def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
+        """
+        Exchange recv displacement arrays between all ranks via bootstrap send/recv.
+
+        Each rank needs to know where to write in each peer's output buffer.
+        remoteRecvDispls[peer] = peer's recvDispls[rank]: the byte offset in
+        peer's output buffer where data from this rank should be placed.
+
+        Uses a deadlock-free pattern: the lower-ranked peer sends first, then receives.
+
+        Args:
+            recv_displs_bytes: This rank's recv displacement array (in bytes)
+
+        Returns:
+            List of remote recv displacements (one per rank, in bytes).
+            remoteRecvDispls[rank] == recv_displs_bytes[rank] (self entry, unused by kernel)
+        """
+        rank = self._rank
+        world_size = self._world_size
+        bootstrap = self._comm.bootstrap()
+
+        # Use CPU int64 tensors as send/recv buffers (data_ptr() gives uintptr_t)
+        my_displs = torch.tensor(recv_displs_bytes, dtype=torch.int64)  # CPU tensor
+        data_size = world_size * 8  # world_size int64 values = world_size * 8 bytes
+
+        # Collect all ranks' recv_displs via pairwise send/recv
+        all_recv_displs = [None] * world_size
+        all_recv_displs[rank] = list(recv_displs_bytes)
+
+        for peer in range(world_size):
+            if peer == rank:
+                continue
+            peer_buf = torch.zeros(world_size, dtype=torch.int64)  # CPU tensor
+            if rank < peer:
+                # Lower rank sends first to avoid deadlock
+                bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
+                bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
+            else:
+                # Higher rank receives first
+                bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
+                bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
+            all_recv_displs[peer] = peer_buf.tolist()
+
+        # Build remoteRecvDispls: for each peer, the offset in that peer's output
+        # buffer where this rank's data should go.
+        remote_recv_displs = []
+        for peer in range(world_size):
+            remote_recv_displs.append(int(all_recv_displs[peer][rank]))
+
+        return remote_recv_displs
+
     def __del__(self):
         """Cleanup resources."""
         # Let CUDA handle tensor cleanup automatically
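Note: a minimal worked example of what _exchange_recv_displs above computes, not part of the patch. The helper name and the 3-rank byte sizes are invented for illustration; the point is only that remote_recv_displs[peer] is peer's recvDispls entry for this rank.

# Illustration only: recompute remote recv displacements from already-gathered
# per-rank recv_displs arrays, skipping the bootstrap exchange.
def compute_remote_recv_displs(all_recv_displs, rank):
    # all_recv_displs[r] is rank r's recv_displs array (byte offsets)
    return [all_recv_displs[peer][rank] for peer in range(len(all_recv_displs))]

# Example with 3 ranks and unequal recv counts (bytes):
#   rank 0 receives [4, 8, 4] -> recv_displs [0, 4, 12]
#   rank 1 receives [8, 4, 4] -> recv_displs [0, 8, 12]
#   rank 2 receives [4, 4, 8] -> recv_displs [0, 4, 8]
all_displs = [[0, 4, 12], [0, 8, 12], [0, 4, 8]]
# Rank 1's data lands at offset 4 in rank 0's output, 8 in its own, 4 in rank 2's.
assert compute_remote_recv_displs(all_displs, rank=1) == [4, 8, 4]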
diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index bd7ca1f9..d5f1a6d5 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -109,14 +109,25 @@ def main():
     if rank == 0:
         print("\n[Test 2] Variable-size all-to-all (MoE-like)")
 
-    # Simulate MoE token distribution: rank 0 sends more to rank 0, etc.
-    input_split_sizes = [(i + 1) * 512 for i in range(world_size)]
-    output_split_sizes = [512 * (rank + 1)] * world_size
+    # Simulate MoE token distribution with imbalanced routing.
+    # Build a full send matrix so each rank has different per-peer sizes.
+    # send_matrix[i][j] = number of elements rank i sends to rank j.
+    # For consistency: rank i's output_split[j] = send_matrix[j][i].
+    import random
+    random.seed(42)
+    send_matrix = []
+    for i in range(world_size):
+        row = [random.randint(128, 2048) for _ in range(world_size)]
+        send_matrix.append(row)
+
+    input_split_sizes = send_matrix[rank]  # what this rank sends to each peer
+    output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]  # what this rank receives from each peer
 
     total_input = sum(input_split_sizes)
     total_output = sum(output_split_sizes)
 
-    input_tensor = torch.randn(total_input, dtype=torch.float32, device='cuda')
+    # Fill input with rank-specific pattern for verification
+    input_tensor = torch.arange(total_input, dtype=torch.float32, device='cuda') + rank * 100000
     output_tensor = torch.empty(total_output, dtype=torch.float32, device='cuda')
 
     output = alltoallv.all_to_all_single(
@@ -127,43 +138,67 @@
     )
 
     torch.cuda.synchronize()
+
+    # Verify: the local-to-local segment should match exactly
+    local_send_offset = sum(input_split_sizes[:rank])
+    local_recv_offset = sum(output_split_sizes[:rank])
+    local_size = input_split_sizes[rank]  # == output_split_sizes[rank]
+    expected_local = input_tensor[local_send_offset:local_send_offset + local_size]
+    actual_local = output_tensor[local_recv_offset:local_recv_offset + local_size]
+    local_ok = torch.allclose(expected_local, actual_local)
+
     if rank == 0:
-        print(f"  Input splits: {input_split_sizes}")
-        print(f"  Output splits: {output_split_sizes}")
+        print(f"  Send matrix row (rank 0 sends): {input_split_sizes}")
+        print(f"  Recv sizes (rank 0 receives): {output_split_sizes}")
         print(f"  Input total: {total_input}, Output total: {total_output}")
-        print(f"  PASS")
+        print(f"  Local copy verified: {local_ok}")
+        print(f"  {'PASS' if local_ok else 'FAIL'}")
 
-    # Test 3: Performance benchmark
+    # Test 3: Performance benchmark across message sizes (1KB to 128MB)
     if rank == 0:
-        print("\n[Test 3] Performance benchmark (1MB per rank)")
-
-    msg_size = 1024 * 1024  # 1MB per message
-    input_size = msg_size * world_size
-
-    input_tensor = torch.randn(input_size // 4, dtype=torch.float32, device='cuda')  # 4 bytes per float
-    output_tensor = torch.empty_like(input_tensor)
-
-    # Warmup
-    for _ in range(5):
-        output = alltoallv.all_to_all_single(input_tensor, output=output_tensor)
-    torch.cuda.synchronize()
-
-    # Benchmark
-    n_iters = 20
-    start = time.perf_counter()
-    for _ in range(n_iters):
-        output = alltoallv.all_to_all_single(input_tensor, output=output_tensor)
-    torch.cuda.synchronize()
-    elapsed = time.perf_counter() - start
-
-    # Calculate bandwidth
-    total_bytes = 2 * input_size * n_iters  # read + write
-    bandwidth_gbps = total_bytes / elapsed / 1e9
-
-    if rank == 0:
-        print(f"  {n_iters} iterations in {elapsed*1000:.2f} ms")
-        print(f"  Bandwidth: {bandwidth_gbps:.2f} GB/s")
-        print(f"  Per-iteration: {elapsed/n_iters*1000:.3f} ms")
+        print("\n[Test 3] Performance benchmark (1KB to 128MB per rank)")
+        print(f"  {'Msg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'BW (GB/s)':>10s}")
+        print(f"  {'-'*10} {'-'*5} {'-'*10} {'-'*10} {'-'*10}")
+
+    # Message sizes: 1KB, 4KB, 16KB, 64KB, 256KB, 1MB, 4MB, 16MB, 64MB, 128MB
+    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]  # powers of 4 from 1KB to 64MB
+    msg_sizes.append(128 * 1024 * 1024)  # add 128MB
+
+    for msg_size in msg_sizes:
+        input_size = msg_size * world_size
+        n_elems = input_size // 4  # float32 = 4 bytes
+
+        input_tensor = torch.randn(n_elems, dtype=torch.float32, device='cuda')
+        output_tensor = torch.empty_like(input_tensor)
+
+        # Fewer warmup/iters for very large sizes
+        n_warmup = 3 if msg_size >= 16 * 1024 * 1024 else 5
+        n_iters = 5 if msg_size >= 64 * 1024 * 1024 else (10 if msg_size >= 4 * 1024 * 1024 else 20)
+
+        # Warmup
+        for _ in range(n_warmup):
+            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
+        torch.cuda.synchronize()
+
+        # Benchmark
+        start = time.perf_counter()
+        for _ in range(n_iters):
+            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
+        torch.cuda.synchronize()
+        elapsed = time.perf_counter() - start
+
+        total_bytes = 2 * input_size * n_iters  # read + write
+        bandwidth_gbps = total_bytes / elapsed / 1e9
+        latency_us = elapsed / n_iters * 1e6
+
+        if rank == 0:
+            if msg_size >= 1024 * 1024:
+                size_str = f"{msg_size // (1024*1024)}MB"
+            elif msg_size >= 1024:
+                size_str = f"{msg_size // 1024}KB"
+            else:
+                size_str = f"{msg_size}B"
+            print(f"  {size_str:>10s} {n_iters:>5d} {elapsed*1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>10.2f}")
 
     # Cleanup
     dist.barrier()
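Note: a sketch, not part of the patch, of the consistency property the new send matrix relies on. Because every rank seeds random with the same value, all ranks build an identical matrix, so what rank i sends to rank j equals what rank j expects from rank i. The build_splits helper below is hypothetical and only restates the test's construction.

import random

def build_splits(world_size, rank, seed=42):
    rng = random.Random(seed)  # same seed on every rank -> same matrix everywhere
    send_matrix = [[rng.randint(128, 2048) for _ in range(world_size)]
                   for _ in range(world_size)]
    input_split_sizes = send_matrix[rank]                                    # row: what we send
    output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]   # column: what we receive
    return input_split_sizes, output_split_sizes

# Consistency check across all rank pairs: sends[i][j] must equal recvs[j][i].
ws = 4
sends = [build_splits(ws, r)[0] for r in range(ws)]
recvs = [build_splits(ws, r)[1] for r in range(ws)]
assert all(sends[i][j] == recvs[j][i] for i in range(ws) for j in range(ws))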
diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
index f6318129..2e6ffbe8 100644
--- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
+++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
@@ -85,9 +85,11 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
   auto it_sendDispls = extras.find("sendDispls");
   auto it_recvCounts = extras.find("recvCounts");
   auto it_recvDispls = extras.find("recvDispls");
+  auto it_remoteRecvDispls = extras.find("remoteRecvDispls");
 
   if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
-      it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
+      it_recvCounts == extras.end() || it_recvDispls == extras.end() ||
+      it_remoteRecvDispls == extras.end()) {
     return CommResult::CommInternalError;
   }
 
@@ -95,6 +97,7 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
   const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
   const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
   const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);
+  const size_t* d_remoteRecvDispls = reinterpret_cast<const size_t*>(it_remoteRecvDispls->second);
 
   // Use maximum threads (1024) for best bandwidth utilization
   const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;
@@ -114,7 +117,8 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
         rank, worldSize,
         input, output,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   } else if (worldSize > WORLD_SIZE_THRESHOLD) {
     // Large messages + large world: use ring kernel to avoid congestion
     alltoallvRingKernel<<<1, threadsPerBlock, 0, stream>>>(
@@ -122,7 +126,8 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
         rank, worldSize,
         input, output,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   } else {
     // Large messages + small world: use pipelined chunked kernel
     alltoallvPipelinedKernel<<<1, threadsPerBlock, 0, stream>>>(
@@ -130,7 +135,8 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
         rank, worldSize,
         input, output,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   }
 
   if (cudaGetLastError() == cudaSuccess) {
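Note: a minimal sketch, not part of the patch, of the host-side contract this kernel function checks. Every key looked up in extras must be present and must name a device pointer to a worldSize-long array of 64-bit values; the Python caller satisfies this by passing data_ptr() of int64 CUDA tensors. The build_extras helper name is hypothetical.

import torch

def build_extras(d_send_counts, d_send_displs, d_recv_counts, d_recv_displs,
                 d_remote_recv_displs):
    tensors = {
        "sendCounts": d_send_counts,
        "sendDispls": d_send_displs,
        "recvCounts": d_recv_counts,
        "recvDispls": d_recv_displs,
        "remoteRecvDispls": d_remote_recv_displs,  # new key required by this patch
    }
    for name, t in tensors.items():
        # int64 on the GPU matches the size_t* the kernel function casts to
        assert t.is_cuda and t.dtype == torch.int64, f"{name} must be an int64 CUDA tensor"
    return {name: t.data_ptr() for name, t in tensors.items()}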
diff --git a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
index 20864546..b690d204 100644
--- a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
+++ b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
@@ -48,7 +48,8 @@ __global__ void __launch_bounds__(1024)
     const size_t* sendCounts,
     const size_t* sendDispls,
     const size_t* recvCounts,
-    const size_t* recvDispls) {
+    const size_t* recvDispls,
+    const size_t* remoteRecvDispls) {
   int tid = threadIdx.x;
   int nThreads = blockDim.x;
   int nPeers = worldSize - 1;
@@ -70,7 +71,7 @@
     if (sendCounts[peer] > 0) {
       // Use all threads for maximum copy throughput
       memoryChannels[chanIdx].put(
-          recvDispls[rank],        // dst offset in peer's buffer
+          remoteRecvDispls[peer],  // dst offset in peer's buffer (peer's recvDispls[rank])
           sendDispls[peer],        // src offset in our buffer
           sendCounts[peer],        // size
           tid,                     // thread id
@@ -113,7 +114,8 @@ __global__ void __launch_bounds__(1024)
     const size_t* sendCounts,
     const size_t* sendDispls,
     const size_t* recvCounts,
-    const size_t* recvDispls) {
+    const size_t* recvDispls,
+    const size_t* remoteRecvDispls) {
   int tid = threadIdx.x;
   int nThreads = blockDim.x;
   int nPeers = worldSize - 1;
@@ -133,7 +135,7 @@
     size_t sendSize = sendCounts[peer];
     size_t recvSize = recvCounts[peer];
-    size_t dstOffset = recvDispls[rank];
+    size_t dstOffset = remoteRecvDispls[peer];  // peer's recvDispls[rank]
     size_t srcOffset = sendDispls[peer];
 
     // Send data in chunks for better memory access patterns
@@ -182,7 +184,8 @@ __global__ void __launch_bounds__(1024)
     const size_t* sendCounts,
     const size_t* sendDispls,
     const size_t* recvCounts,
-    const size_t* recvDispls) {
+    const size_t* recvDispls,
+    const size_t* remoteRecvDispls) {
   int tid = threadIdx.x;
   int nThreads = blockDim.x;
 
@@ -203,7 +206,7 @@
     // Send data to sendPeer using ALL threads
     if (sendCounts[sendPeer] > 0) {
       memoryChannels[chanIdx].put(
-          recvDispls[rank],
+          remoteRecvDispls[sendPeer],  // dst offset in peer's buffer (peer's recvDispls[rank])
           sendDispls[sendPeer],
           sendCounts[sendPeer],
           tid,
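Note: a small numeric illustration, not part of the patch and with invented sizes, of why the put() destination in the kernels above changes from recvDispls[rank] to remoteRecvDispls[peer]. The correct destination offset lives in the peer's receive layout; our own recvDispls[rank] only happens to match it when every rank uses identical split sizes.

# Per-rank output layouts in bytes, unequal splits (same example as earlier).
recv_displs = {
    0: [0, 4, 12],
    1: [0, 8, 12],
    2: [0, 4, 8],
}
rank, peer = 1, 2
old_dst = recv_displs[rank][rank]   # 8: an offset in our own output layout, wrong target
new_dst = recv_displs[peer][rank]   # 4: rank 2's recvDispls[1], the correct slot for our data
assert (old_dst, new_dst) == (8, 4)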
diff --git a/test/mscclpp-test/alltoallv_test.cu b/test/mscclpp-test/alltoallv_test.cu
index efeea803..2d3740a3 100644
--- a/test/mscclpp-test/alltoallv_test.cu
+++ b/test/mscclpp-test/alltoallv_test.cu
@@ -27,6 +27,7 @@
 static size_t* d_sendCounts;
 static size_t* d_sendDispls;
 static size_t* d_recvCounts;
 static size_t* d_recvDispls;
+static size_t* d_remoteRecvDispls;  // peer's recvDispls[rank] for each peer
 
 // Device array for memory channels (used by library kernels)
 static DeviceHandle<mscclpp::MemoryChannel>* d_memoryChannels;
@@ -67,7 +68,8 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
         rank, worldSize,
         localSendBuffV, localRecvBuffV,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   } else if (kernelNum == 1) {
     // Use ring-based kernel for larger world sizes
     mscclpp::collective::alltoallvRingKernel<<<1, nThreads, 0, stream>>>(
@@ -75,7 +77,8 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
         rank, worldSize,
         localSendBuffV, localRecvBuffV,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   } else if (kernelNum == 2) {
     // Use pipelined kernel for imbalanced workloads (MoE)
     mscclpp::collective::alltoallvPipelinedKernel<<<1, nThreads, 0, stream>>>(
@@ -83,7 +86,8 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
         rank, worldSize,
         localSendBuffV, localRecvBuffV,
         d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls);
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   }
 }
 
@@ -121,6 +125,13 @@ void AllToAllVTestColl::initData(const TestArgs& args, std::vector sendBu
   CUDATHROW(cudaMemcpy(d_sendDispls, sendDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
   CUDATHROW(cudaMemcpy(d_recvCounts, recvCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
   CUDATHROW(cudaMemcpy(d_recvDispls, recvDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
+  // remoteRecvDispls[peer] = peer's recvDispls[rank] = where our data goes in peer's output.
+  // For equal splits, all ranks have the same layout, so peer's recvDispls[rank] = our recvDispls[rank].
+  std::vector<size_t> remoteRecvDispls(worldSize);
+  for (int peer = 0; peer < worldSize; peer++) {
+    remoteRecvDispls[peer] = recvDispls_[rank];
+  }
+  CUDATHROW(cudaMemcpy(d_remoteRecvDispls, remoteRecvDispls.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
 }
 
 void AllToAllVTestColl::getBw(const double deltaSec, double& algBw, double& busBw) {
@@ -214,6 +225,7 @@ void AllToAllVTestEngine::allocateBuffer() {
   CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t)));
   CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t)));
   CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t)));
+  CUDATHROW(cudaMalloc(&d_remoteRecvDispls, args_.totalRanks * sizeof(size_t)));
 
   // Allocate device array for memory channels
   CUDATHROW(cudaMalloc(&d_memoryChannels, args_.totalRanks * sizeof(DeviceHandle<mscclpp::MemoryChannel>)));
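Note: a sketch, not part of the patch, checking the equal-splits shortcut used by initData above. With identical per-peer counts every rank's recvDispls array is the same, so the general rule (peer's recvDispls[rank]) collapses to this rank's own recvDispls[rank] and no host-side exchange is needed in the standalone test. Sizes and the helper name are invented.

def exclusive_prefix_sum(counts):
    # Byte displacements from per-peer byte counts.
    displs, off = [], 0
    for c in counts:
        displs.append(off)
        off += c
    return displs

world_size, count, rank = 4, 1024, 2
all_displs = [exclusive_prefix_sum([count] * world_size) for _ in range(world_size)]
general = [all_displs[peer][rank] for peer in range(world_size)]  # what the kernels need
shortcut = [all_displs[rank][rank]] * world_size                  # what initData fills in
assert general == shortcut == [2048, 2048, 2048, 2048]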