ext/ep: split combine 'reduce' bucket into grid_sync + reduce-arith

Add a per-block timestamp slot just after cg::this_grid().sync() in the
combine kernel so the previous 'reduce' window is decomposed into:
  [b*8 + 0] send entry
  [b*8 + 1] send done
  [b*8 + 2] recv-flag spinwait done
  [b*8 + 3] grid_sync done
  [b*8 + 4] kernel done

At TOKENS=128/TOPK=8 over IBGDA this reveals the breakdown to be
  send=23us  wait=200us  grid_sync=110us  reduce=22us  total=355us
rather than the previously assumed 'reduce=130us' (which lumped grid_sync
into the arithmetic). The actual int4-load + bf16-FMA pass is only ~22us,
so a TMA-pipelined receive cannot meaningfully recover the gap vs nccl-ep:
the difference is on the sender / RDMA-arrival side, not in the reducer.
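
For reference, the five slots split one block's timeline into four adjacent
windows; a minimal host-side sketch of the mapping (variable names here are
illustrative, not taken from the code):

    // ts[0..4]: the five clock64() samples recorded by one block, in order.
    uint64_t send      = ts[1] - ts[0];  // send entry    -> send done
    uint64_t wait      = ts[2] - ts[1];  // send done     -> recv-flag spinwait done
    uint64_t grid_sync = ts[3] - ts[2];  // spinwait done -> grid sync done
    uint64_t reduce    = ts[4] - ts[3];  // grid sync done -> kernel done
    uint64_t total     = ts[4] - ts[0];  // corresponds to stats(0, 4) below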
Qinghua Zhou
2026-05-08 21:50:00 +00:00
parent f63bf15378
commit b557fe289a
2 changed files with 23 additions and 14 deletions


@@ -1707,7 +1707,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
// Worst-case block count: combine launcher uses kNumWarpGroupsRdma=3, so
// num_sms = ceil(num_experts / 3). Read 1024 entries to be safe.
constexpr int kMaxBlocks = 1024;
-std::vector<uint64_t> host_ts(kMaxBlocks * 4);
+constexpr int kCombSlots = 8;
+std::vector<uint64_t> host_ts(kMaxBlocks * kCombSlots);
cudaMemcpy(host_ts.data(),
static_cast<char*>(workspace) + 196608,
host_ts.size() * sizeof(uint64_t),
@@ -1726,8 +1727,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
double sum = 0;
int n = 0;
for (int b = 0; b < kMaxBlocks; ++b) {
-uint64_t lo = host_ts[b * 4 + idx_lo];
-uint64_t hi = host_ts[b * 4 + idx_hi];
+uint64_t lo = host_ts[b * kCombSlots + idx_lo];
+uint64_t hi = host_ts[b * kCombSlots + idx_hi];
if (lo == 0 || hi <= lo) continue;
uint64_t d = hi - lo;
if (d < mn) mn = d;
@@ -1743,17 +1744,20 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
};
auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
-auto [redu_mn, redu_av, redu_mx, redu_n] = stats(2, 3);
-auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
+auto [gsy_mn, gsy_av, gsy_mx, gsy_n ] = stats(2, 3);
+auto [redu_mn, redu_av, redu_mx, redu_n] = stats(3, 4);
+auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 4);
fprintf(stderr,
"[ep-prof combine #%d r%d] blocks=%d "
"send=%.1f/%.1f/%.1fus "
"wait=%.1f/%.1f/%.1fus "
"grid_sync=%.1f/%.1f/%.1fus "
"reduce=%.1f/%.1f/%.1fus "
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
this_call, rank, send_n,
send_mn, send_av, send_mx,
wait_mn, wait_av, wait_mx,
+gsy_mn, gsy_av, gsy_mx,
redu_mn, redu_av, redu_mx,
tot_mn, tot_av, tot_mx);
// Zero out so next iter's "lo==0" filter rejects untouched slots.
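
The stats() windows above are computed on raw clock64() deltas but printed as
microseconds, so a tick-to-us conversion has to happen somewhere outside this
hunk. A sketch of the assumed scaling (the attribute query is standard CUDA;
its exact placement in this file is an assumption):

    // clock64() counts SM cycles; cudaDevAttrClockRate reports the SM
    // clock in kHz, so us = ticks * 1000 / khz.
    int khz = 0;
    cudaDeviceGetAttribute(&khz, cudaDevAttrClockRate, /*device=*/0);
    auto ticks_to_us = [khz](uint64_t ticks) {
      return static_cast<double>(ticks) * 1000.0 / khz;
    };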


@@ -619,12 +619,14 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
// combine prof_buf well past that to avoid corrupting dispatch's
// atomic counters between iterations (which previously caused iter 1+
// to hang in dispatch's FINISHED_SUM_TAG spinwait).
-// Per-block [4]uint64_t scratch. Slots:
-// [b*4 + 0] = send phase entry
-// [b*4 + 1] = send phase done (after trailing flag write barrier)
-// [b*4 + 2] = recv-flag spinwait done
-// [b*4 + 3] = kernel done
+// Per-block scratch. Slots:
+// [b*8 + 0] = send phase entry
+// [b*8 + 1] = send phase done (after trailing flag write barrier)
+// [b*8 + 2] = recv-flag spinwait done
+// [b*8 + 3] = grid-sync done (entry to reduce arithmetic)
+// [b*8 + 4] = kernel done
// Offset 196608 is past dispatch's prof_buf (1024 + 1024*16*8 = 132096).
+constexpr int kCombProfSlots = 8;
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
reinterpret_cast<char*>(atomic_clean_flag) + 196608);
#endif
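
The 196608 offset and the 132096 figure are bare constants shared between the
dispatch and combine instrumentation; a compile-time guard along these lines
(hypothetical, not part of the diff, and reading the comment's arithmetic as
bytes) would pin the layout assumption down:

    // Hypothetical guard: combine's prof_buf must start past dispatch's
    // region (1024 B of counters + 1024 blocks * 16 slots * 8 B = 132096 B).
    static_assert(196608 >= 1024 + 1024 * 16 * 8,
                  "combine prof_buf overlaps dispatch's prof_buf");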
@@ -632,7 +634,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
#ifdef MSCCLPP_EP_LL_PROFILE
-if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
+if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 0] = clock64();
#endif
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
@@ -748,7 +750,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
#ifdef MSCCLPP_EP_LL_PROFILE
__syncthreads();
-if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
+if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 1] = clock64();
#endif
LOW_LATENCY_COMBINE_RECV:
@@ -762,9 +764,12 @@ LOW_LATENCY_COMBINE_RECV:
}
#ifdef MSCCLPP_EP_LL_PROFILE
__syncthreads();
-if (thread_id == 0) prof_buf[sm_id * 4 + 2] = clock64();
+if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 2] = clock64();
#endif
cg::this_grid().sync();
+#ifdef MSCCLPP_EP_LL_PROFILE
+if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 3] = clock64();
+#endif
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
@@ -803,7 +808,7 @@ LOW_LATENCY_COMBINE_RECV:
}
#ifdef MSCCLPP_EP_LL_PROFILE
__syncthreads();
-if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
+if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 4] = clock64();
#endif
}
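
The grid_sync window isolated by this commit is the cost of
cg::this_grid().sync(), which is only valid under a cooperative launch (every
block resident at once). A standalone toy sketch of the same measurement
pattern, independent of this codebase (compile for sm_60+ with -rdc=true):

    #include <cooperative_groups.h>
    #include <cuda_runtime.h>
    namespace cg = cooperative_groups;

    // Toy kernel: bracket the grid-wide barrier with per-block clock64()
    // samples, mirroring slots [2] and [3] above.
    __global__ void sync_probe(unsigned long long* ts) {
      if (threadIdx.x == 0) ts[blockIdx.x * 2 + 0] = clock64();
      cg::this_grid().sync();
      if (threadIdx.x == 0) ts[blockIdx.x * 2 + 1] = clock64();
    }

    void launch_probe(unsigned long long* ts, int blocks, int threads) {
      // Grid sync requires cudaLaunchCooperativeKernel, not a plain
      // <<<...>>> launch, and blocks must not exceed what can co-reside.
      void* args[] = {&ts};
      cudaLaunchCooperativeKernel(reinterpret_cast<void*>(sync_probe),
                                  dim3(blocks), dim3(threads), args);
    }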