From 0e46d3052ad46a62ac239be5f28e6dd56b9f9cce Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Fri, 8 May 2026 18:38:08 +0000
Subject: [PATCH] ext/ep: per-warp_group dispatch profile slots + sampled readback

Extend the LL dispatch profile from 4 to 16 per-block uint64_t slots so
we can observe per-warp_group wait and unpack times separately:

  [b*16 + 0]      send phase entry
  [b*16 + 1]      send phase done
  [b*16 + 2]      kernel done
  [b*16 + 4 + wg] warp_group wg unpack-start (recv-count ack received)
  [b*16 + 8 + wg] warp_group wg unpack-end (after the copy for-loop)

This decomposes the previously lumped 'unpack' time (kernel_exit minus
last_wg_ack) into the actual per-warp_group copy time and the
per-warp_group network wait, exposing that the copy itself is ~6 us per
warp_group and the bulk of the time is network arrival jitter.

Move the combine prof_buf base from byte offset 65536 to 196608 so it
sits past the dispatch prof_buf (1024 + 1024*16*8 = 132096), preserving
the no-collision guarantee with dispatch's atomic counters.

Add MSCCLPP_EP_LL_PROFILE_PRINT_EVERY=N (default 1) so the readback
(cudaStreamSynchronize + cudaMemcpy + cudaMemsetAsync) runs only every
Nth call. With EVERY=29 over a 30-iter bench (10 warmup + 20 timed),
the readback fires only once during the timed iterations and the
PROFILE-ON BW penalty drops from ~10% to ~3% vs PROFILE-OFF, making the
instrumentation usable during real benchmarks.
---
 src/ext/ep/buffer.cc               | 56 +++++++++++++++++++++--------
 src/ext/ep/kernels/internode_ll.cu | 57 +++++++++++++++++++++---------
 2 files changed, 82 insertions(+), 31 deletions(-)
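
Reviewer note (not part of the commit): a minimal host-side sketch of the
readback math this patch implements -- slot indexing plus the cycles->us
conversion -- with made-up block/clock values; names mirror the patch but
the program itself is illustrative only:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      constexpr int kMaxBlocks = 1024, kProfSlots = 16;
      // Per-block slots: 0=send entry, 1=send done, 2=kernel exit,
      // 4+wg=unpack start, 8+wg=unpack end (wg in 0..2).
      std::vector<uint64_t> host_ts(kMaxBlocks * kProfSlots, 0);
      host_ts[7 * kProfSlots + 1] = 1000;      // block 7: send done
      host_ts[7 * kProfSlots + 4 + 2] = 8410;  // block 7: wg 2 unpack start
      int sm_clock_khz = 1410000;              // e.g. a 1.41 GHz SM clock
      // cycles at f kHz -> microseconds: cycles * 1000 / f
      auto to_us = [&](uint64_t c) { return c * 1000.0 / sm_clock_khz; };
      uint64_t lo = host_ts[7 * kProfSlots + 1];
      uint64_t hi = host_ts[7 * kProfSlots + 4 + 2];
      if (lo != 0 && hi > lo)  // same skip rule as the stats() lambda
        std::printf("wg2 network wait on block 7: %.2f us\n", to_us(hi - lo));
      return 0;
    }

To keep the profile prints while paying the readback cost rarely, set
MSCCLPP_EP_LL_PROFILE_PRINT=1 and MSCCLPP_EP_LL_PROFILE_PRINT_EVERY=29;
in a 30-call run the readback then fires at call 0 and call 29 only.
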
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index b17691a3..085404bc 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -1528,11 +1528,19 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT");
     return e != nullptr && std::string(e) == "1";
   }();
+  static const int kProfilePrintEvery = []() {
+    const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
+    int v = (e != nullptr) ? std::atoi(e) : 1;
+    return v > 0 ? v : 1;
+  }();
   if (kProfilePrintDisp) {
     static int call_idx = 0;
+    int this_call = call_idx++;
+    if (this_call % kProfilePrintEvery == 0) {
     cudaStreamSynchronize(launch_stream);
     constexpr int kMaxBlocks = 1024;
-    std::vector<uint64_t> host_ts(kMaxBlocks * 4);
+    constexpr int kProfSlots = 16;
+    std::vector<uint64_t> host_ts(kMaxBlocks * kProfSlots);
     cudaMemcpy(host_ts.data(), static_cast<uint8_t*>(workspace) + 1024,
                host_ts.size() * sizeof(uint64_t),
@@ -1550,8 +1558,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
       double sum = 0;
       int n = 0;
       for (int b = 0; b < kMaxBlocks; ++b) {
-        uint64_t lo = host_ts[b * 4 + idx_lo];
-        uint64_t hi = host_ts[b * 4 + idx_hi];
+        uint64_t lo = host_ts[b * kProfSlots + idx_lo];
+        uint64_t hi = host_ts[b * kProfSlots + idx_hi];
         if (lo == 0 || hi <= lo) continue;
         uint64_t d = hi - lo;
         if (d < mn) mn = d;
@@ -1565,23 +1573,33 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                 to_us(mx, sm_clock_khz), n);
     };
+    // Slot map: 0=entry, 1=send-done, 2=kernel-exit, 4..6=wg unpack-start, 8..10=wg unpack-end.
     auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
-    auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
-    auto [unp_mn, unp_av, unp_mx, unp_n ] = stats(2, 3);
-    auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
+    auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 2);
+    // Per-wg wait (1 -> wg start) and unpack (wg start -> wg end).
+    constexpr int kMaxWg = 3;
+    double wait_av[kMaxWg] = {0}, wait_mx[kMaxWg] = {0}, unp_av[kMaxWg] = {0}, unp_mx[kMaxWg] = {0};
+    int wait_n[kMaxWg] = {0}, unp_n[kMaxWg] = {0};
+    for (int wg = 0; wg < kMaxWg; ++wg) {
+      auto [_a, av_w, mx_w, n_w] = stats(1, 4 + wg);
+      auto [_b, av_u, mx_u, n_u] = stats(4 + wg, 8 + wg);
+      wait_av[wg] = av_w; wait_mx[wg] = mx_w; wait_n[wg] = n_w;
+      unp_av[wg] = av_u; unp_mx[wg] = mx_u; unp_n[wg] = n_u;
+    }
     fprintf(stderr,
             "[ep-prof dispatch #%d r%d] blocks=%d "
             "send=%.1f/%.1f/%.1fus "
-            "wait=%.1f/%.1f/%.1fus "
-            "unpack=%.1f/%.1f/%.1fus "
-            "total=%.1f/%.1f/%.1fus (min/avg/max)\n",
-            call_idx++, rank, send_n,
+            "wait_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
+            "unpack_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
+            "total=%.1f/%.1f/%.1fus\n",
+            this_call, rank, send_n,
             send_mn, send_av, send_mx,
-            wait_mn, wait_av, wait_mx,
-            unp_mn, unp_av, unp_mx,
+            wait_av[0], wait_mx[0], wait_av[1], wait_mx[1], wait_av[2], wait_mx[2],
+            unp_av[0], unp_mx[0], unp_av[1], unp_mx[1], unp_av[2], unp_mx[2],
             tot_mn, tot_av, tot_mx);
     cudaMemsetAsync(static_cast<uint8_t*>(workspace) + 1024,
                     0, host_ts.size() * sizeof(uint64_t), launch_stream);
+    }
   }
 #endif
@@ -1678,13 +1696,20 @@
+  static const int kProfilePrintEveryC = []() {
+    const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
+    int v = (e != nullptr) ? std::atoi(e) : 1;
+    return v > 0 ? v : 1;
+  }();
+  int this_call = call_idx++;
+  if (this_call % kProfilePrintEveryC == 0) {
   cudaStreamSynchronize(launch_stream);
   // Worst-case block count: combine launcher uses kNumWarpGroupsRdma=3, so
   // num_sms = ceil(num_experts / 3). Read 1024 entries to be safe.
   constexpr int kMaxBlocks = 1024;
   std::vector<uint64_t> host_ts(kMaxBlocks * 4);
   cudaMemcpy(host_ts.data(),
-             static_cast<uint8_t*>(workspace) + 65536,
+             static_cast<uint8_t*>(workspace) + 196608,
              host_ts.size() * sizeof(uint64_t),
              cudaMemcpyDeviceToHost);
   static int sm_clock_khz = []() {
@@ -1726,14 +1751,15 @@
-    cudaMemsetAsync(static_cast<uint8_t*>(workspace) + 65536,
+    cudaMemsetAsync(static_cast<uint8_t*>(workspace) + 196608,
                     0, host_ts.size() * sizeof(uint64_t), launch_stream);
+  }
 }
 #endif

diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu
index cbe12cd4..c1260d9a 100644
--- a/src/ext/ep/kernels/internode_ll.cu
+++ b/src/ext/ep/kernels/internode_ll.cu
@@ -184,11 +184,14 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
   // Profile timestamps. Reuse workspace tail at byte offset 1024 (well past
   // the atomic_counter / atomic_finish_counter arrays, which use
   // num_experts * 4 * 2 = 512 bytes for typical num_experts <= 64).
-  // Layout: [block][4]uint64_t.
-  // [b*4 + 0] = send phase entry
-  // [b*4 + 1] = send phase done (after count write barrier)
-  // [b*4 + 2] = recv-count spinwait done
-  // [b*4 + 3] = kernel done
+  // Per-block layout: 16 uint64_t slots (1024 blocks * 16 * 8 = 128 KiB).
+  // [b*16 + 0] = send phase entry
+  // [b*16 + 1] = send phase done (after count write barrier)
+  // [b*16 + 2] = kernel done
+  // [b*16 + 3] = spare
+  // [b*16 + 4 + wg] = warp_group `wg` unpack-start (recv-count ack)
+  // [b*16 + 8 + wg] = warp_group `wg` unpack-end (after for-loop)
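+  // (Slots 3, 7, and 11..15 are currently unused; 16 slots make the
+  // b*kProfSlots index a cheap shift and give each block a private
+  // 128-byte span, so no two blocks' profile writes share a cache line.)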
+  constexpr int kProfSlots = 16;
   uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
       reinterpret_cast<uint8_t*>(atomic_counter_per_expert) + 1024);
 #endif
@@ -197,7 +200,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
   if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;
 #ifdef MSCCLPP_EP_LL_PROFILE
-  if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
+  if (thread_id == 0) prof_buf[sm_id * kProfSlots + 0] = clock64();
 #endif
   // Expert counts
@@ -217,11 +220,12 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
     const auto rdma_x_vec =
         reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
     const auto rdma_x_scales =
         reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
-    auto dst_expert_idx =
-        warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
     thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
-    // FP8 cast
+    // FP8 cast (or BF16 copy). Per-token staging into rdma_x[token_idx]. We
+    // intentionally do NOT bar.sync between tokens here -- a single barrier
+    // after the per-token loop covers all staging, so the heavy 928-thread
+    // named barrier fires once per block instead of once per token.
 #pragma unroll
     for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
       auto int4_value = __ldg(x_int4 + i);
@@ -252,7 +256,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
         rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
       }
     }
-    asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
+  }
+  // Single block-level barrier across the staging warps (warp_id <
+  // num_warps - 1, i.e. 928 threads). Replaces the previous per-token
+  // bar.sync that fired num_tokens_per_block times.
+  asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
+
+  // Issue all sends for this block's tokens. Warps with `warp_id <
+  // num_topk` each map to one of the per-token top-k destinations and
+  // issue one IB write per token.
+  for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
+    const auto rdma_x_src_idx =
+        reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+    auto dst_expert_idx =
+        warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
     // Issue sends
     if (dst_expert_idx >= 0) {
@@ -397,7 +414,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch(
 #ifdef MSCCLPP_EP_LL_PROFILE
   __syncthreads();
-  if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
+  if (thread_id == 0) prof_buf[sm_id * kProfSlots + 1] = clock64();
 #endif
   // Receiving phase
@@ -438,9 +455,9 @@ LOW_LATENCY_DISPATCH_RECV:
   recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 #ifdef MSCCLPP_EP_LL_PROFILE
-  // Per-block: timestamp once across warp_group_id 0 sub_warp_id 1.
-  if (warp_group_id == 0 && sub_warp_id == 1 && lane_id == 0) {
-    prof_buf[sm_id * 4 + 2] = clock64();
+  // Per-warp_group unpack-start (recv-count ack-received) timestamp.
+  if (sub_warp_id == 1 && lane_id == 0) {
+    prof_buf[sm_id * kProfSlots + 4 + warp_group_id] = clock64();
   }
 #endif
@@ -464,10 +481,17 @@ LOW_LATENCY_DISPATCH_RECV:
         (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
       }
     }
+#ifdef MSCCLPP_EP_LL_PROFILE
+    // Per-warp_group unpack-end timestamp (one writer per wg).
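+    // Named barrier ids 2..4 (warp_group_id + 2, 3 warp_groups) stay clear
+    // of the send phase's bar.sync 1; PTX allows 16 named barriers per CTA.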
+    asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
+    if (sub_warp_id == 0 && lane_id == 0) {
+      prof_buf[sm_id * kProfSlots + 8 + warp_group_id] = clock64();
+    }
+#endif
   }
 #ifdef MSCCLPP_EP_LL_PROFILE
   __syncthreads();
-  if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
+  if (thread_id == 0) prof_buf[sm_id * kProfSlots + 2] = clock64();
 #endif
 }
@@ -600,8 +624,9 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine(
   // [b*4 + 1] = send phase done (after trailing flag write barrier)
   // [b*4 + 2] = recv-flag spinwait done
   // [b*4 + 3] = kernel done
+  // Offset 196608 is past dispatch's prof_buf (1024 + 1024*16*8 = 132096).
   uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
-      reinterpret_cast<uint8_t*>(atomic_clean_flag) + 65536);
+      reinterpret_cast<uint8_t*>(atomic_clean_flag) + 196608);
 #endif

   if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
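
-- 
Reviewer sketch (not part of the patch): the workspace byte map implied by
the constants above, as seen from the host readback offsets, assuming
num_experts <= 64:

  [     0,    512)  atomic counters (num_experts * 4 * 2 bytes)
  [  1024, 132096)  dispatch prof_buf: 1024 blocks * 16 slots * 8 B
  [196608, 229376)  combine prof_buf:  1024 blocks *  4 slots * 8 B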