diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index 2e55b98d..955ce4b9 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -227,6 +227,70 @@ def main():
                 size_str = f"{avg_msg_size}B"
             print(f" {size_str:>10s} {n_iters:>5d} {elapsed*1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>12.2f}")
 
+    # Test 4: torch.distributed.all_to_all_single baseline (same variable-size data)
+    if rank == 0:
+        print("\n[Test 4] torch.dist.all_to_all_single baseline (same variable sizes)")
+        print(f" {'Avg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'algBW(GB/s)':>12s}")
+        print(f" {'-'*10} {'-'*5} {'-'*10} {'-'*10} {'-'*12}")
+
+    for avg_msg_size in msg_sizes:
+        # Rebuild the same send_matrix (same seed → same data)
+        import random
+        random.seed(12345)
+        avg_elems = avg_msg_size // 4
+        send_matrix = []
+        for i in range(world_size):
+            row = []
+            for j in range(world_size):
+                factor = 0.5 + random.random()
+                elems = max(1, int(avg_elems * factor))
+                row.append(elems)
+            send_matrix.append(row)
+
+        input_split_sizes = send_matrix[rank]
+        output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]
+
+        total_send = sum(input_split_sizes)
+        total_recv = sum(output_split_sizes)
+
+        input_tensor = torch.randn(total_send, dtype=torch.float32, device='cuda')
+        output_tensor = torch.empty(total_recv, dtype=torch.float32, device='cuda')
+
+        n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
+        n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
+
+        # Warmup
+        for _ in range(n_warmup):
+            dist.all_to_all_single(
+                output_tensor, input_tensor,
+                output_split_sizes=output_split_sizes,
+                input_split_sizes=input_split_sizes)
+        torch.cuda.synchronize()
+
+        # Benchmark
+        start = time.perf_counter()
+        for _ in range(n_iters):
+            dist.all_to_all_single(
+                output_tensor, input_tensor,
+                output_split_sizes=output_split_sizes,
+                input_split_sizes=input_split_sizes)
+        torch.cuda.synchronize()
+        elapsed = time.perf_counter() - start
+
+        total_recv_bytes = total_recv * 4
+        total_bytes = total_recv_bytes * n_iters
+        bandwidth_gbps = total_bytes / elapsed / 1e9
+        latency_us = elapsed / n_iters * 1e6
+
+        if rank == 0:
+            if avg_msg_size >= 1024 * 1024:
+                size_str = f"{avg_msg_size // (1024*1024)}MB"
+            elif avg_msg_size >= 1024:
+                size_str = f"{avg_msg_size // 1024}KB"
+            else:
+                size_str = f"{avg_msg_size}B"
+            print(f" {size_str:>10s} {n_iters:>5d} {elapsed*1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>12.2f}")
+
     # Cleanup
     dist.barrier()
     if rank == 0:
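
The new Test 4 rebuilds the seeded send_matrix so that torch.distributed.all_to_all_single is timed on the same variable splits as the earlier mscclpp tests. The bookkeeping it relies on is the transpose relationship: rank r sends send_matrix[r][j] elements to rank j, so it must receive send_matrix[j][r] elements from each rank j. Below is a minimal standalone sketch of that invariant which runs without a distributed launch; world_size and avg_elems here are illustrative values, not taken from the test (the test itself presumably runs under a torch.distributed launcher such as torchrun, one process per GPU).

    import random

    # Illustrative values; the real test derives these from the launch
    # configuration and its message-size sweep.
    world_size = 4
    avg_elems = 1024

    # Same seed as the test, so every rank would build the identical matrix.
    random.seed(12345)
    send_matrix = []
    for i in range(world_size):
        row = []
        for j in range(world_size):
            factor = 0.5 + random.random()
            row.append(max(1, int(avg_elems * factor)))
        send_matrix.append(row)

    for rank in range(world_size):
        input_split_sizes = send_matrix[rank]        # row: what `rank` sends to each peer
        output_split_sizes = [send_matrix[j][rank]   # column: what `rank` receives from each peer
                              for j in range(world_size)]
        assert output_split_sizes == [row[rank] for row in send_matrix]

    # Conservation: every element some rank sends is an element some rank receives.
    total_sent = sum(sum(row) for row in send_matrix)
    total_received = sum(send_matrix[j][r]
                         for r in range(world_size)
                         for j in range(world_size))
    assert total_sent == total_received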