ext/ep: HT perf - lower lazy head-feedback threshold to chunk/4

Problem: kForwarderCoordinator only published min_head when
min_head >= last_head + num_max_rdma_chunked_send_tokens. With
4096 tokens / 10 channels / 2 peers ~= 205 tokens per (channel,peer),
the receive-buffer-space window advanced only every full chunk, and
the last partial chunk never triggered an update, serializing
handshakes and capping HT throughput.

Fix: lower threshold to max(1, chunk/4); when any forwarder channel
has retired, drop to 1 so partial tail chunks always publish.

Result on 2-node GB200 NVL72 (cfg=20,8,256,16,128, 4096 tok / 7168 hidden):
  dispatch agg: 36 -> 78 GB/s  (2.16x)
  combine  agg: 45 -> 95 GB/s  (2.11x)
  PASS, max diff = 0

Also strips Phase 4 diagnostic printfs.
This commit is contained in:
Qinghua Zhou
2026-05-10 19:21:25 +00:00
parent e0a1bb2c42
commit 01a10e00de

View File

@@ -292,23 +292,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
// mapping until the slot reaches `epoch * num_ranks` (everyone
// arrived). Replaces the pairwise PortChannel signal/wait above.
if (thread_id == 0) {
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b0\n", rank, (unsigned long long)nvls_epoch);
if (rank == 0) (void)0;
uint64_t* mc_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier);
uint64_t* dev_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier);
uint64_t pre;
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(pre) : "l"(dev_b0) : "memory");
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b0) : "memory");
printf("[nvls] rank=%d enter b0 epoch=%llu pre=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)pre, (unsigned long long)(nvls_epoch * (uint64_t)num_ranks));
(void)0;
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
uint64_t v;
int spin = 0;
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b0) : "memory");
if ((++spin & 0xFFFFFFF) == 0) {
printf("[nvls] rank=%d spinning b0 v=%llu expected=%llu spin=%d\n", rank, (unsigned long long)v, (unsigned long long)expected, spin);
(void)0;
}
} while (v < expected);
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b0 v=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v, (unsigned long long)expected);
if (rank == 0) (void)0;
}
} else if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
@@ -386,7 +386,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
// NVLS epoch barrier #1: ensure every sender's multimem.st has been
// delivered to all peers before we read.
if (thread_id == 0) {
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b1\n", rank, (unsigned long long)nvls_epoch);
if (rank == 0) (void)0;
uint64_t* mc_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 8);
uint64_t* dev_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 8);
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b1) : "memory");
@@ -395,7 +395,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b1) : "memory");
} while (v < expected);
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b1\n", rank, (unsigned long long)nvls_epoch);
if (rank == 0) (void)0;
}
__syncthreads();
// Receiver: for each sender s, copy num_elems ints from the NVLS slot
@@ -507,7 +507,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
// Finally barrier
__syncthreads();
if (thread_id == 0) printf("[ph4-N] rank=%d enter final-barrier run=%d ch=%d\n", rank, (int)run_barrier, barrier_channel_idx);
if (thread_id == 0) (void)0;
if (nvls_mc_ptr != nullptr) {
// Phase 4 / NVLS HT "B2": replace the cross-node port_channel
@@ -516,7 +516,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
// CX-7 RoCE that hangs the entire CTA at the !kLowLatencyMode
// __syncthreads below (thread 33's wait never completes).
if (thread_id == 0) {
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b2\n", rank, (unsigned long long)nvls_epoch);
if (rank == 0) (void)0;
uint64_t* mc_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 16);
uint64_t* dev_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 16);
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b2) : "memory");
@@ -525,13 +525,13 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b2) : "memory");
} while (v < expected);
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b2 v=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v);
if (rank == 0) (void)0;
}
} else if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].wait();
}
if (thread_id == 0) printf("[ph4-N] rank=%d pass final-barrier\n", rank);
if (thread_id == 0) (void)0;
if constexpr (!kLowLatencyMode) {
// kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync
// after the inter-node sync is done.
@@ -539,7 +539,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
}
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
if (thread_id == 0) printf("[ph4-N] rank=%d notify_dispatch DONE sm=%d\n", rank, sm_id);
if (thread_id == 0) (void)0;
} else {
// Calculate meta data
int dst_rdma_rank = sm_id - 1;
@@ -693,8 +693,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
if (thread_id == 0 && sm_id == 0) {
printf("[ph4-K] dispatch entry rank=%d sm=%d num_channels=%d kNumRDMARanks=%d\n",
rank, sm_id, num_channels, kNumRDMARanks);
(void)0;
}
// Data checks
@@ -976,8 +975,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
// Synchronize shared memory
sync_rdma_sender_smem();
if (lane_id == 0 && channel_id == 0 && rank == 0) {
printf("[ph4-s] sender entry rank=%d ch=%d kNumRDMARanks=%d\n",
rank, channel_id, kNumRDMARanks);
(void)0;
}
// Get number of tokens to send for each RDMA rank
@@ -1123,8 +1121,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
}
// Phase 4 diag: report final issued tail per (channel, dst_rdma_rank).
if (lane_id < kNumRDMARanks && rank == 0) {
printf("[ph4-Sx] sender exit rank=%d ch=%d dst=%d issued_total=%d\n",
rank, channel_id, lane_id, last_issued_tail);
(void)0;
}
} else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
// RDMA consumers and NVL producers
@@ -1160,8 +1157,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
// Phase 4 diag: report received expected token count per (channel, src).
if (rank == 4) {
printf("[ph4-Fx] fwd meta rank=%d ch=%d dst_nvl=%d src_rdma=%d expect_rdma=%d\n",
rank, channel_id, dst_nvl_rank, lane_id, num_tokens_to_recv_from_rdma);
(void)0;
}
break;
}
@@ -1243,9 +1239,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
num_tokens_to_recv_from_rdma);
(void)0;
}
// Iterate over every token from the RDMA buffer
@@ -1314,7 +1308,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
if (lane_id == 0) {
forward_channel_retired[dst_nvl_rank] = true;
if (channel_id == 0) {
printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
(void)0;
}
}
} else if (warp_role == WarpRole::kForwarderCoordinator) {
@@ -1342,7 +1336,18 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max())) break;
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
// Phase 4 perf: lower the lazy-update threshold from chunk_send → chunk_send/4
// (or 1 if retired) so the sender's receive-buffer-space window
// advances more frequently. The original threshold meant partial
// last chunks never triggered head feedback, serializing handshakes
// and capping throughput at larger chunk_send values where
// 4096 tokens / 10 channels / 2 peers
// ≈ 205 tokens ⇒ 3 full chunks + 1 partial.
const bool any_retired = forward_channel_retired[0] || forward_channel_retired[1] ||
forward_channel_retired[2] || forward_channel_retired[3] ||
forward_channel_retired[4] || forward_channel_retired[5] ||
forward_channel_retired[6] || forward_channel_retired[7];
const int head_update_threshold = any_retired ? 1 : max(1, num_max_rdma_chunked_send_tokens / 4);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + head_update_threshold and
lane_id < kNumRDMARanks) {
if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
// Phase 4 fix: cross-node head feedback via direct fabric-VA
@@ -1476,7 +1481,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
}
}
if (thread_id == 0 && channel_id == 0) {
printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
(void)0;
}
}
@@ -1549,7 +1554,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
// Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) {
if (thread_id == 0) {
printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
(void)0;
}
// Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
// the port_channel signal/wait pair (broken on Azure CX-7).
@@ -1800,8 +1805,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
if (rank == 0 && thread_id == 0 && channel_id == 0) {
printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
(void)0;
}
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / 32;
@@ -1831,7 +1835,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
if (thread_id == 0 && channel_id == 0 && rank == 0) {
printf("[ph4-C] combine entry rank=%d sm=%d num_channels=%d\n", rank, sm_id, num_channels);
(void)0;
}
if (warp_role == WarpRole::kNVLSender) {