mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
ext/ep: HT perf - lower lazy head-feedback threshold to chunk/4
Problem: kForwarderCoordinator only published min_head when min_head >= last_head + num_max_rdma_chunked_send_tokens. With 4096 tokens / 10 channels / 2 peers ~= 205 tokens per (channel,peer), the receive-buffer-space window advanced only every full chunk, and the last partial chunk never triggered an update, serializing handshakes and capping HT throughput. Fix: lower threshold to max(1, chunk/4); when any forwarder channel has retired, drop to 1 so partial tail chunks always publish. Result on 2-node GB200 NVL72 (cfg=20,8,256,16,128, 4096 tok / 7168 hidden): dispatch agg: 36 -> 78 GB/s (2.16x) combine agg: 45 -> 95 GB/s (2.11x) PASS, max diff = 0 Also strips Phase 4 diagnostic printfs.
This commit is contained in:
@@ -292,23 +292,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
// mapping until the slot reaches `epoch * num_ranks` (everyone
|
||||
// arrived). Replaces the pairwise PortChannel signal/wait above.
|
||||
if (thread_id == 0) {
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b0\n", rank, (unsigned long long)nvls_epoch);
|
||||
if (rank == 0) (void)0;
|
||||
uint64_t* mc_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier);
|
||||
uint64_t* dev_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier);
|
||||
uint64_t pre;
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(pre) : "l"(dev_b0) : "memory");
|
||||
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b0) : "memory");
|
||||
printf("[nvls] rank=%d enter b0 epoch=%llu pre=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)pre, (unsigned long long)(nvls_epoch * (uint64_t)num_ranks));
|
||||
(void)0;
|
||||
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
|
||||
uint64_t v;
|
||||
int spin = 0;
|
||||
do {
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b0) : "memory");
|
||||
if ((++spin & 0xFFFFFFF) == 0) {
|
||||
printf("[nvls] rank=%d spinning b0 v=%llu expected=%llu spin=%d\n", rank, (unsigned long long)v, (unsigned long long)expected, spin);
|
||||
(void)0;
|
||||
}
|
||||
} while (v < expected);
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b0 v=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v, (unsigned long long)expected);
|
||||
if (rank == 0) (void)0;
|
||||
}
|
||||
} else if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
@@ -386,7 +386,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
// NVLS epoch barrier #1: ensure every sender's multimem.st has been
|
||||
// delivered to all peers before we read.
|
||||
if (thread_id == 0) {
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b1\n", rank, (unsigned long long)nvls_epoch);
|
||||
if (rank == 0) (void)0;
|
||||
uint64_t* mc_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 8);
|
||||
uint64_t* dev_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 8);
|
||||
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b1) : "memory");
|
||||
@@ -395,7 +395,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
do {
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b1) : "memory");
|
||||
} while (v < expected);
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b1\n", rank, (unsigned long long)nvls_epoch);
|
||||
if (rank == 0) (void)0;
|
||||
}
|
||||
__syncthreads();
|
||||
// Receiver: for each sender s, copy num_elems ints from the NVLS slot
|
||||
@@ -507,7 +507,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
|
||||
// Finally barrier
|
||||
__syncthreads();
|
||||
if (thread_id == 0) printf("[ph4-N] rank=%d enter final-barrier run=%d ch=%d\n", rank, (int)run_barrier, barrier_channel_idx);
|
||||
if (thread_id == 0) (void)0;
|
||||
|
||||
if (nvls_mc_ptr != nullptr) {
|
||||
// Phase 4 / NVLS HT "B2": replace the cross-node port_channel
|
||||
@@ -516,7 +516,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
// CX-7 RoCE that hangs the entire CTA at the !kLowLatencyMode
|
||||
// __syncthreads below (thread 33's wait never completes).
|
||||
if (thread_id == 0) {
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b2\n", rank, (unsigned long long)nvls_epoch);
|
||||
if (rank == 0) (void)0;
|
||||
uint64_t* mc_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 16);
|
||||
uint64_t* dev_b2 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 16);
|
||||
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b2) : "memory");
|
||||
@@ -525,13 +525,13 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
do {
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b2) : "memory");
|
||||
} while (v < expected);
|
||||
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b2 v=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v);
|
||||
if (rank == 0) (void)0;
|
||||
}
|
||||
} else if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
port_channel_handles[barrier_channel_idx].wait();
|
||||
}
|
||||
if (thread_id == 0) printf("[ph4-N] rank=%d pass final-barrier\n", rank);
|
||||
if (thread_id == 0) (void)0;
|
||||
if constexpr (!kLowLatencyMode) {
|
||||
// kLowLatencyMode==false requires sync of all ranks, which can be done by running intra-node sync
|
||||
// after the inter-node sync is done.
|
||||
@@ -539,7 +539,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
|
||||
}
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
if (thread_id == 0) printf("[ph4-N] rank=%d notify_dispatch DONE sm=%d\n", rank, sm_id);
|
||||
if (thread_id == 0) (void)0;
|
||||
} else {
|
||||
// Calculate meta data
|
||||
int dst_rdma_rank = sm_id - 1;
|
||||
@@ -693,8 +693,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
auto target_rank = role_meta.second; // Not applicable for RDMA senders
|
||||
EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
|
||||
if (thread_id == 0 && sm_id == 0) {
|
||||
printf("[ph4-K] dispatch entry rank=%d sm=%d num_channels=%d kNumRDMARanks=%d\n",
|
||||
rank, sm_id, num_channels, kNumRDMARanks);
|
||||
(void)0;
|
||||
}
|
||||
|
||||
// Data checks
|
||||
@@ -976,8 +975,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
// Synchronize shared memory
|
||||
sync_rdma_sender_smem();
|
||||
if (lane_id == 0 && channel_id == 0 && rank == 0) {
|
||||
printf("[ph4-s] sender entry rank=%d ch=%d kNumRDMARanks=%d\n",
|
||||
rank, channel_id, kNumRDMARanks);
|
||||
(void)0;
|
||||
}
|
||||
|
||||
// Get number of tokens to send for each RDMA rank
|
||||
@@ -1123,8 +1121,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
}
|
||||
// Phase 4 diag: report final issued tail per (channel, dst_rdma_rank).
|
||||
if (lane_id < kNumRDMARanks && rank == 0) {
|
||||
printf("[ph4-Sx] sender exit rank=%d ch=%d dst=%d issued_total=%d\n",
|
||||
rank, channel_id, lane_id, last_issued_tail);
|
||||
(void)0;
|
||||
}
|
||||
} else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
|
||||
// RDMA consumers and NVL producers
|
||||
@@ -1160,8 +1157,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
|
||||
// Phase 4 diag: report received expected token count per (channel, src).
|
||||
if (rank == 4) {
|
||||
printf("[ph4-Fx] fwd meta rank=%d ch=%d dst_nvl=%d src_rdma=%d expect_rdma=%d\n",
|
||||
rank, channel_id, dst_nvl_rank, lane_id, num_tokens_to_recv_from_rdma);
|
||||
(void)0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -1243,9 +1239,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
|
||||
|
||||
if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
|
||||
printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
|
||||
rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
|
||||
num_tokens_to_recv_from_rdma);
|
||||
(void)0;
|
||||
}
|
||||
|
||||
// Iterate over every token from the RDMA buffer
|
||||
@@ -1314,7 +1308,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
if (lane_id == 0) {
|
||||
forward_channel_retired[dst_nvl_rank] = true;
|
||||
if (channel_id == 0) {
|
||||
printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
|
||||
(void)0;
|
||||
}
|
||||
}
|
||||
} else if (warp_role == WarpRole::kForwarderCoordinator) {
|
||||
@@ -1342,7 +1336,18 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max())) break;
|
||||
|
||||
// Update remote head
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
|
||||
// Phase 4 perf: lower the lazy-update threshold from chunk_send → chunk_send/4
|
||||
// (or 1 if retired) so the sender's receive-buffer-space window
|
||||
// advances more frequently. The original threshold caused partial
|
||||
// last chunks to never trigger head feedback, which deadlocked at
|
||||
// larger chunk_send values where 4096 tokens / 10 channels / 2 peers
|
||||
// ≈ 205 tokens ⇒ 3 full chunks + 1 partial.
|
||||
const bool any_retired = forward_channel_retired[0] || forward_channel_retired[1] ||
|
||||
forward_channel_retired[2] || forward_channel_retired[3] ||
|
||||
forward_channel_retired[4] || forward_channel_retired[5] ||
|
||||
forward_channel_retired[6] || forward_channel_retired[7];
|
||||
const int head_update_threshold = any_retired ? 1 : max(1, num_max_rdma_chunked_send_tokens / 4);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + head_update_threshold and
|
||||
lane_id < kNumRDMARanks) {
|
||||
if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
|
||||
// Phase 4 fix: cross-node head feedback via direct fabric-VA
|
||||
@@ -1476,7 +1481,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
}
|
||||
}
|
||||
if (thread_id == 0 && channel_id == 0) {
|
||||
printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
|
||||
(void)0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1549,7 +1554,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
// Using two SMs, which clean the RDMA/NVL buffer respectively
|
||||
if (sm_id == 0) {
|
||||
if (thread_id == 0) {
|
||||
printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
|
||||
(void)0;
|
||||
}
|
||||
// Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
|
||||
// the port_channel signal/wait pair (broken on Azure CX-7).
|
||||
@@ -1800,8 +1805,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
// NOTES: we decouple a channel into 2 SMs
|
||||
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
if (rank == 0 && thread_id == 0 && channel_id == 0) {
|
||||
printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
|
||||
rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
|
||||
(void)0;
|
||||
}
|
||||
auto role_meta = [=]() -> std::pair<WarpRole, int> {
|
||||
auto warp_id = thread_id / 32;
|
||||
@@ -1831,7 +1835,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
|
||||
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
|
||||
if (thread_id == 0 && channel_id == 0 && rank == 0) {
|
||||
printf("[ph4-C] combine entry rank=%d sm=%d num_channels=%d\n", rank, sm_id, num_channels);
|
||||
(void)0;
|
||||
}
|
||||
|
||||
if (warp_role == WarpRole::kNVLSender) {
|
||||
|
||||
Reference in New Issue
Block a user