From 01a10e00de656bfe97072600602e78f1fa6deabd Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Sun, 10 May 2026 19:21:25 +0000
Subject: [PATCH] ext/ep: HT perf - lower lazy head-feedback threshold to chunk/4

Problem: kForwarderCoordinator only published min_head when
min_head >= last_head + num_max_rdma_chunked_send_tokens. With
4096 tokens / 10 channels / 2 peers ~= 205 tokens per (channel,peer),
the receive-buffer-space window advanced only every full chunk, and the
last partial chunk never triggered an update, serializing handshakes and
capping HT throughput.

Fix: lower threshold to max(1, chunk/4); when any forwarder channel has
retired, drop to 1 so partial tail chunks always publish.

Result on 2-node GB200 NVL72 (cfg=20,8,256,16,128, 4096 tok / 7168 hidden):
  dispatch agg: 36 -> 78 GB/s (2.16x)
  combine agg:  45 -> 95 GB/s (2.11x)
  PASS, max diff = 0

Also strips Phase 4 diagnostic printfs.
---
 src/ext/ep/kernels/internode.cu | 62 ++++++++++++++++++---------------
 1 file changed, 33 insertions(+), 29 deletions(-)

diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu
index 70124e38..a5b3b165 100644
--- a/src/ext/ep/kernels/internode.cu
+++ b/src/ext/ep/kernels/internode.cu
@@ -292,23 +292,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
         // mapping until the slot reaches `epoch * num_ranks` (everyone
         // arrived). Replaces the pairwise PortChannel signal/wait above.
         if (thread_id == 0) {
-          if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b0\n", rank, (unsigned long long)nvls_epoch);
+          if (rank == 0) (void)0;
           uint64_t* mc_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier);
           uint64_t* dev_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier);
           uint64_t pre;
           asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(pre) : "l"(dev_b0) : "memory");
           asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b0) : "memory");
-          printf("[nvls] rank=%d enter b0 epoch=%llu pre=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)pre, (unsigned long long)(nvls_epoch * (uint64_t)num_ranks));
+          (void)0;
           const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
           uint64_t v;
           int spin = 0;
           do {
             asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b0) : "memory");
             if ((++spin & 0xFFFFFFF) == 0) {
-              printf("[nvls] rank=%d spinning b0 v=%llu expected=%llu spin=%d\n", rank, (unsigned long long)v, (unsigned long long)expected, spin);
+              (void)0;
             }
           } while (v < expected);
-          if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b0 v=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v, (unsigned long long)expected);
+          if (rank == 0) (void)0;
         }
       } else if (run_barrier) {
         port_channel_handles[barrier_channel_idx].signal();
@@ -386,7 +386,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
         // NVLS epoch barrier #1: ensure every sender's multimem.st has been
         // delivered to all peers before we read.
         if (thread_id == 0) {
-          if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b1\n", rank, (unsigned long long)nvls_epoch);
+          if (rank == 0) (void)0;
           uint64_t* mc_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 8);
           uint64_t* dev_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 8);
           asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b1) : "memory");
@@ -395,7 +395,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
           do {
             asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b1) : "memory");
           } while (v < expected);
-          if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b1\n", rank, (unsigned long long)nvls_epoch);
+          if (rank == 0) (void)0;
         }
         __syncthreads();
         // Receiver: for each sender s, copy num_elems ints from the NVLS slot
@@ -507,7 +507,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co

     // Finally barrier
     __syncthreads();
-    if (thread_id == 0) printf("[ph4-N] rank=%d enter final-barrier run=%d ch=%d\n", rank, (int)run_barrier, barrier_channel_idx);
+    if (thread_id == 0) (void)0;

     if (nvls_mc_ptr != nullptr) {
       // Phase 4 / NVLS HT "B2": replace the cross-node port_channel
@@ -516,7 +516,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
       // CX-7 RoCE that hangs the entire CTA at the !kLowLatencyMode
       // __syncthreads below (thread 33's wait never completes).
       if (thread_id == 0) {
-        if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b2\n", rank, (unsigned long long)nvls_epoch);
+        if (rank == 0) (void)0;
         uint64_t* mc_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 16);
         uint64_t* dev_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 16);
         asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b2) : "memory");
@@ -525,13 +525,13 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
         do {
           asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b2) : "memory");
         } while (v < expected);
-        if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b2 v=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v);
+        if (rank == 0) (void)0;
       }
     } else if (run_barrier) {
       port_channel_handles[barrier_channel_idx].signal();
       port_channel_handles[barrier_channel_idx].wait();
     }
-    if (thread_id == 0) printf("[ph4-N] rank=%d pass final-barrier\n", rank);
+    if (thread_id == 0) (void)0;
     if constexpr (!kLowLatencyMode) {
       // kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync
       // after the inter-node sync is done.
@@ -539,7 +539,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
     }
     barrier_device(task_fifo_ptrs, head, nvl_rank);
     move_fifo_slots(head);
-    if (thread_id == 0) printf("[ph4-N] rank=%d notify_dispatch DONE sm=%d\n", rank, sm_id);
+    if (thread_id == 0) (void)0;
   } else {
     // Calculate meta data
     int dst_rdma_rank = sm_id - 1;
@@ -693,8 +693,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
   auto target_rank = role_meta.second;  // Not applicable for RDMA senders
   EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
   if (thread_id == 0 && sm_id == 0) {
-    printf("[ph4-K] dispatch entry rank=%d sm=%d num_channels=%d kNumRDMARanks=%d\n",
-           rank, sm_id, num_channels, kNumRDMARanks);
+    (void)0;
   }

   // Data checks
@@ -976,8 +975,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
     // Synchronize shared memory
     sync_rdma_sender_smem();
     if (lane_id == 0 && channel_id == 0 && rank == 0) {
-      printf("[ph4-s] sender entry rank=%d ch=%d kNumRDMARanks=%d\n",
-             rank, channel_id, kNumRDMARanks);
+      (void)0;
     }

     // Get number of tokens to send for each RDMA rank
@@ -1123,8 +1121,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
     }
     // Phase 4 diag: report final issued tail per (channel, dst_rdma_rank).
     if (lane_id < kNumRDMARanks && rank == 0) {
-      printf("[ph4-Sx] sender exit rank=%d ch=%d dst=%d issued_total=%d\n",
-             rank, channel_id, lane_id, last_issued_tail);
+      (void)0;
     }
   } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
     // RDMA consumers and NVL producers
@@ -1160,8 +1157,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
         EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
         // Phase 4 diag: report received expected token count per (channel, src).
         if (rank == 4) {
-          printf("[ph4-Fx] fwd meta rank=%d ch=%d dst_nvl=%d src_rdma=%d expect_rdma=%d\n",
-                 rank, channel_id, dst_nvl_rank, lane_id, num_tokens_to_recv_from_rdma);
+          (void)0;
         }
         break;
       }
@@ -1243,9 +1239,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
       auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);

       if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
-        printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
-               rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
-               num_tokens_to_recv_from_rdma);
+        (void)0;
       }

       // Iterate over every token from the RDMA buffer
@@ -1314,7 +1308,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
       if (lane_id == 0) {
         forward_channel_retired[dst_nvl_rank] = true;
         if (channel_id == 0) {
-          printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
+          (void)0;
         }
       }
     } else if (warp_role == WarpRole::kForwarderCoordinator) {
@@ -1342,7 +1336,18 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
       if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max())) break;

       // Update remote head
-      if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
+      // Phase 4 perf: lower the lazy-update threshold from chunk_send → chunk_send/4
+      // (or 1 if retired) so the sender's receive-buffer-space window
+      // advances more frequently. The original threshold caused partial
+      // last chunks to never trigger head feedback, which deadlocked at
+      // larger chunk_send values where 4096 tokens / 10 channels / 2 peers
+      // ≈ 205 tokens ⇒ 3 full chunks + 1 partial.
+      const bool any_retired = forward_channel_retired[0] || forward_channel_retired[1] ||
+                               forward_channel_retired[2] || forward_channel_retired[3] ||
+                               forward_channel_retired[4] || forward_channel_retired[5] ||
+                               forward_channel_retired[6] || forward_channel_retired[7];
+      const int head_update_threshold = any_retired ? 1 : max(1, num_max_rdma_chunked_send_tokens / 4);
+      if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + head_update_threshold and
           lane_id < kNumRDMARanks) {
         if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
           // Phase 4 fix: cross-node head feedback via direct fabric-VA
@@ -1476,7 +1481,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
     }
   }
   if (thread_id == 0 && channel_id == 0) {
-    printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
+    (void)0;
   }
 }

@@ -1549,7 +1554,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
   // Using two SMs, which clean the RDMA/NVL buffer respectively
   if (sm_id == 0) {
     if (thread_id == 0) {
-      printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
+      (void)0;
     }
     // Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
     // the port_channel signal/wait pair (broken on Azure CX-7).
@@ -1800,8 +1805,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
   // NOTES: we decouple a channel into 2 SMs
   const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
   if (rank == 0 && thread_id == 0 && channel_id == 0) {
-    printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
-           rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
+    (void)0;
   }
   auto role_meta = [=]() -> std::pair<WarpRole, int> {
     auto warp_id = thread_id / 32;
@@ -1831,7 +1835,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
     EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
     auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
     if (thread_id == 0 && channel_id == 0 && rank == 0) {
-      printf("[ph4-C] combine entry rank=%d sm=%d num_channels=%d\n", rank, sm_id, num_channels);
+      (void)0;
     }

     if (warp_role == WarpRole::kNVLSender) {
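
Note (illustrative, not part of the patch): the sketch below models the coordinator's lazy head-feedback rule on the host, assuming a hypothetical chunk size of 64 tokens and the roughly 205 tokens per (channel, peer) cited in the commit message. It shows how much more often chunk/4 publishes head feedback than one full chunk; in the kernel, the retired case additionally drops the threshold to 1 so the partial tail always publishes.

// head_feedback_sketch.cpp - standalone host-side model, not kernel code.
#include <algorithm>
#include <cstdio>

// Mirrors the coordinator rule: publish min_head once it is at least
// `threshold` tokens ahead of the last published head.
static int count_updates(int total_tokens, int threshold, int* unreported) {
  int last_head = 0, updates = 0;
  for (int min_head = 1; min_head <= total_tokens; ++min_head) {
    if (min_head >= last_head + threshold) {
      last_head = min_head;  // head feedback published to the sender
      ++updates;
    }
  }
  *unreported = total_tokens - last_head;  // tail tokens the sender never sees
  return updates;
}

int main() {
  const int chunk = 64;    // stand-in for num_max_rdma_chunked_send_tokens (assumed)
  const int tokens = 205;  // ~4096 tokens / 10 channels / 2 peers
  const int new_thr = std::max(1, chunk / 4);
  int tail_old = 0, tail_new = 0;
  const int old_updates = count_updates(tokens, chunk, &tail_old);
  const int new_updates = count_updates(tokens, new_thr, &tail_new);
  std::printf("old threshold %d: %d updates, %d tail tokens unreported\n",
              chunk, old_updates, tail_old);
  std::printf("new threshold %d: %d updates, %d tail tokens unreported\n",
              new_thr, new_updates, tail_new);
  return 0;
}

With these numbers the old rule publishes only 3 times per (channel, peer) while chunk/4 publishes 12 times, so the sender's receive-buffer window advances roughly four times as often; the remaining 13-token tail is covered in both cases only once a forwarder channel retires and the threshold drops to 1, which is the second half of the fix.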