diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py index cdd4397c..2e55b98d 100644 --- a/python/test/test_alltoallv_mscclpp.py +++ b/python/test/test_alltoallv_mscclpp.py @@ -154,54 +154,77 @@ def main(): print(f" Local copy verified: {local_ok}") print(f" {'PASS' if local_ok else 'FAIL'}") - # Test 3: Performance benchmark across message sizes (1KB to 128MB) + # Test 3: Performance benchmark with variable sizes (1KB to 128MB avg per peer) if rank == 0: - print("\n[Test 3] Performance benchmark (1KB to 128MB per rank)") - print(f" {'Msg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'algBW(GB/s)':>12s}") + print("\n[Test 3] Variable-size performance benchmark (1KB to 128MB avg per peer)") + print(f" {'Avg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'algBW(GB/s)':>12s}") print(f" {'-'*10} {'-'*5} {'-'*10} {'-'*10} {'-'*12}") - # Message sizes: 1KB, 4KB, 16KB, 64KB, 256KB, 1MB, 4MB, 16MB, 64MB, 128MB + # Message sizes: average bytes sent to each peer msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0] # powers of 4 from 1KB to 64MB msg_sizes.append(128 * 1024 * 1024) # add 128MB - for msg_size in msg_sizes: - input_size = msg_size * world_size - n_elems = input_size // 4 # float32 = 4 bytes + for avg_msg_size in msg_sizes: + # Build a variable send matrix: send_matrix[i][j] = bytes rank i sends to rank j. + # Use a deterministic seed so all ranks compute the same matrix. + # Sizes vary from 0.5× to 1.5× of avg_msg_size (in float32 elements). + import random + random.seed(12345) + avg_elems = avg_msg_size // 4 # float32 = 4 bytes + send_matrix = [] + for i in range(world_size): + row = [] + for j in range(world_size): + # Random factor between 0.5 and 1.5 + factor = 0.5 + random.random() + elems = max(1, int(avg_elems * factor)) + row.append(elems) + send_matrix.append(row) - input_tensor = torch.randn(n_elems, dtype=torch.float32, device='cuda') - output_tensor = torch.empty_like(input_tensor) + input_split_sizes = send_matrix[rank] + output_split_sizes = [send_matrix[j][rank] for j in range(world_size)] + + total_send = sum(input_split_sizes) + total_recv = sum(output_split_sizes) + + input_tensor = torch.randn(total_send, dtype=torch.float32, device='cuda') + output_tensor = torch.empty(total_recv, dtype=torch.float32, device='cuda') # Fewer warmup/iters for very large sizes - n_warmup = 3 if msg_size >= 16 * 1024 * 1024 else 5 - n_iters = 5 if msg_size >= 64 * 1024 * 1024 else (10 if msg_size >= 4 * 1024 * 1024 else 20) + n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5 + n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20) # Warmup for _ in range(n_warmup): - alltoallv.all_to_all_single(input_tensor, output=output_tensor) + alltoallv.all_to_all_single( + input_tensor, output=output_tensor, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes) torch.cuda.synchronize() # Benchmark start = time.perf_counter() for _ in range(n_iters): - alltoallv.all_to_all_single(input_tensor, output=output_tensor) + alltoallv.all_to_all_single( + input_tensor, output=output_tensor, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes) torch.cuda.synchronize() elapsed = time.perf_counter() - start - # Algorithm bandwidth: total data received per rank / time - # Each rank receives msg_size bytes from each of the other (world_size-1) peers - # plus msg_size from itself (local copy). Total recv = input_size = msg_size * world_size. - # Report unidirectional algBw (same convention as nccl-test). - total_bytes = input_size * n_iters + # Algorithm bandwidth: total bytes received per rank / time (unidirectional) + total_recv_bytes = total_recv * 4 # float32 + total_bytes = total_recv_bytes * n_iters bandwidth_gbps = total_bytes / elapsed / 1e9 latency_us = elapsed / n_iters * 1e6 if rank == 0: - if msg_size >= 1024 * 1024: - size_str = f"{msg_size // (1024*1024)}MB" - elif msg_size >= 1024: - size_str = f"{msg_size // 1024}KB" + if avg_msg_size >= 1024 * 1024: + size_str = f"{avg_msg_size // (1024*1024)}MB" + elif avg_msg_size >= 1024: + size_str = f"{avg_msg_size // 1024}KB" else: - size_str = f"{msg_size}B" + size_str = f"{avg_msg_size}B" print(f" {size_str:>10s} {n_iters:>5d} {elapsed*1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>12.2f}") # Cleanup