From b557fe289aba9f788e879fdd6a34bd9f34c09457 Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Fri, 8 May 2026 21:50:00 +0000 Subject: [PATCH] ext/ep: split combine 'reduce' bucket into grid_sync + reduce-arith Add a per-block timestamp slot just after cg::this_grid().sync() in the combine kernel so the previous 'reduce' window is decomposed into: [b*8 + 0] send entry [b*8 + 1] send done [b*8 + 2] recv-flag spinwait done [b*8 + 3] grid_sync done [b*8 + 4] kernel done At TOKENS=128/TOPK=8 IBGDA this reveals the breakdown is send=23us wait=200us grid_sync=110us reduce=22us total=355us not the previously assumed 'reduce=130us' (which lumped grid_sync into the arithmetic). The actual int4-load + bf16 FMA pass is only ~22us, so a TMA-pipelined receive cannot meaningfully recover the gap vs nccl-ep: the difference is on the sender / RDMA arrival side, not the reducer. --- src/ext/ep/buffer.cc | 14 +++++++++----- src/ext/ep/kernels/internode_ll.cu | 23 ++++++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc index 085404bc..51cd40bb 100644 --- a/src/ext/ep/buffer.cc +++ b/src/ext/ep/buffer.cc @@ -1707,7 +1707,8 @@ std::tuple, std::optional host_ts(kMaxBlocks * 4); + constexpr int kCombSlots = 8; + std::vector host_ts(kMaxBlocks * kCombSlots); cudaMemcpy(host_ts.data(), static_cast(workspace) + 196608, host_ts.size() * sizeof(uint64_t), @@ -1726,8 +1727,8 @@ std::tuple, std::optional, std::optional( reinterpret_cast(atomic_clean_flag) + 196608); #endif @@ -632,7 +634,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; #ifdef MSCCLPP_EP_LL_PROFILE - if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 0] = clock64(); #endif if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { @@ -748,7 +750,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 1] = clock64(); #endif LOW_LATENCY_COMBINE_RECV: @@ -762,9 +764,12 @@ LOW_LATENCY_COMBINE_RECV: } #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 2] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 2] = clock64(); #endif cg::this_grid().sync(); +#ifdef MSCCLPP_EP_LL_PROFILE + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 3] = clock64(); +#endif EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); @@ -803,7 +808,7 @@ LOW_LATENCY_COMBINE_RECV: } #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 4] = clock64(); #endif }