mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
ext/ep: split combine 'reduce' bucket into grid_sync + reduce-arith
Add a per-block timestamp slot just after cg::this_grid().sync() in the combine kernel, so the previous 'reduce' window is decomposed into: [b*8 + 0] send entry; [b*8 + 1] send done; [b*8 + 2] recv-flag spinwait done; [b*8 + 3] grid_sync done; [b*8 + 4] kernel done. At TOKENS=128/TOPK=8 with IBGDA, this reveals the breakdown is send=23us, wait=200us, grid_sync=110us, reduce=22us, total=355us — not the previously assumed 'reduce=130us' (which lumped grid_sync into the arithmetic). The actual int4-load + bf16 FMA pass is only ~22us, so a TMA-pipelined receive cannot meaningfully recover the gap vs nccl-ep: the difference is on the sender / RDMA-arrival side, not the reducer.
This commit is contained in:
@@ -1707,7 +1707,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
// Worst-case block count: combine launcher uses kNumWarpGroupsRdma=3, so
|
||||
// num_sms = ceil(num_experts / 3). Read 1024 entries to be safe.
|
||||
constexpr int kMaxBlocks = 1024;
|
||||
std::vector<uint64_t> host_ts(kMaxBlocks * 4);
|
||||
constexpr int kCombSlots = 8;
|
||||
std::vector<uint64_t> host_ts(kMaxBlocks * kCombSlots);
|
||||
cudaMemcpy(host_ts.data(),
|
||||
static_cast<char*>(workspace) + 196608,
|
||||
host_ts.size() * sizeof(uint64_t),
|
||||
@@ -1726,8 +1727,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
double sum = 0;
|
||||
int n = 0;
|
||||
for (int b = 0; b < kMaxBlocks; ++b) {
|
||||
uint64_t lo = host_ts[b * 4 + idx_lo];
|
||||
uint64_t hi = host_ts[b * 4 + idx_hi];
|
||||
uint64_t lo = host_ts[b * kCombSlots + idx_lo];
|
||||
uint64_t hi = host_ts[b * kCombSlots + idx_hi];
|
||||
if (lo == 0 || hi <= lo) continue;
|
||||
uint64_t d = hi - lo;
|
||||
if (d < mn) mn = d;
|
||||
@@ -1743,17 +1744,20 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
};
|
||||
auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
|
||||
auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
|
||||
auto [redu_mn, redu_av, redu_mx, redu_n] = stats(2, 3);
|
||||
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
|
||||
auto [gsy_mn, gsy_av, gsy_mx, gsy_n ] = stats(2, 3);
|
||||
auto [redu_mn, redu_av, redu_mx, redu_n] = stats(3, 4);
|
||||
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 4);
|
||||
fprintf(stderr,
|
||||
"[ep-prof combine #%d r%d] blocks=%d "
|
||||
"send=%.1f/%.1f/%.1fus "
|
||||
"wait=%.1f/%.1f/%.1fus "
|
||||
"grid_sync=%.1f/%.1f/%.1fus "
|
||||
"reduce=%.1f/%.1f/%.1fus "
|
||||
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
|
||||
this_call, rank, send_n,
|
||||
send_mn, send_av, send_mx,
|
||||
wait_mn, wait_av, wait_mx,
|
||||
gsy_mn, gsy_av, gsy_mx,
|
||||
redu_mn, redu_av, redu_mx,
|
||||
tot_mn, tot_av, tot_mx);
|
||||
// Zero out so next iter's "lo==0" filter rejects untouched slots.
|
||||
|
||||
@@ -619,12 +619,14 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
// combine prof_buf well past that to avoid corrupting dispatch's
|
||||
// atomic counters between iterations (which previously caused iter 1+
|
||||
// to hang in dispatch's FINISHED_SUM_TAG spinwait).
|
||||
// Per-block [4]uint64_t scratch. Slots:
|
||||
// [b*4 + 0] = send phase entry
|
||||
// [b*4 + 1] = send phase done (after trailing flag write barrier)
|
||||
// [b*4 + 2] = recv-flag spinwait done
|
||||
// [b*4 + 3] = kernel done
|
||||
// Per-block scratch. Slots:
|
||||
// [b*8 + 0] = send phase entry
|
||||
// [b*8 + 1] = send phase done (after trailing flag write barrier)
|
||||
// [b*8 + 2] = recv-flag spinwait done
|
||||
// [b*8 + 3] = grid-sync done (entry to reduce arithmetic)
|
||||
// [b*8 + 4] = kernel done
|
||||
// Offset 196608 is past dispatch's prof_buf (1024 + 1024*16*8 = 132096).
|
||||
constexpr int kCombProfSlots = 8;
|
||||
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<char*>(atomic_clean_flag) + 196608);
|
||||
#endif
|
||||
@@ -632,7 +634,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
|
||||
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 0] = clock64();
|
||||
#endif
|
||||
|
||||
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
|
||||
@@ -748,7 +750,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
__syncthreads();
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 1] = clock64();
|
||||
#endif
|
||||
|
||||
LOW_LATENCY_COMBINE_RECV:
|
||||
@@ -762,9 +764,12 @@ LOW_LATENCY_COMBINE_RECV:
|
||||
}
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
__syncthreads();
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 2] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 2] = clock64();
|
||||
#endif
|
||||
cg::this_grid().sync();
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 3] = clock64();
|
||||
#endif
|
||||
|
||||
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
|
||||
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
|
||||
@@ -803,7 +808,7 @@ LOW_LATENCY_COMBINE_RECV:
|
||||
}
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
__syncthreads();
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 4] = clock64();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user