ext/ep: per-warp_group dispatch profile slots + sampled readback

Extend the LL dispatch profile from 4 to 16 per-block uint64_t slots so
we can observe per-warp_group wait and unpack times separately:
  [b*16 + 0]      send phase entry
  [b*16 + 1]      send phase done
  [b*16 + 2]      kernel done
  [b*16 + 4 + wg] warp_group wg unpack-start (recv-count ack received)
  [b*16 + 8 + wg] warp_group wg unpack-end   (after for-loop)
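
A minimal device-side sketch of how these slots are indexed and which deltas
the host readback derives from them (prof_mark is an illustrative helper; the
kernel writes prof_buf directly, as the diff below shows):

  #include <cstdint>
  constexpr int kProfSlots = 16;

  __device__ __forceinline__ void prof_mark(uint64_t* prof_buf, int block_id,
                                            int slot) {
    prof_buf[block_id * kProfSlots + slot] = clock64();
  }

  // Derived metrics in the host readback:
  //   send      = slot[1] - slot[0]
  //   wait_wg   = slot[4 + wg] - slot[1]        (network arrival per warp_group)
  //   unpack_wg = slot[8 + wg] - slot[4 + wg]   (copy time per warp_group)
  //   total     = slot[2] - slot[0]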

This decomposes the previous lumped 'unpack' (kernel_exit minus
last_wg_ack) into the actual per-warp_group copy time and the per-
warp_group network wait, exposing that the copy itself is ~6 us / wg
and the bulk of the time is network arrival jitter.

Move the combine prof_buf base from offset 65536 to 196608 so it sits
past the dispatch prof_buf (1024 + 1024*16*8 = 132096), preserving the
no-collision guarantee with dispatch's atomic counters.
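
A compile-time sketch of that arithmetic (constant names are illustrative;
the values are the ones above):

  #include <cstddef>
  #include <cstdint>
  constexpr std::size_t kDispatchProfBase  = 1024;   // past the atomic counters (<= 512 B used)
  constexpr std::size_t kDispatchProfBytes = 1024 * 16 * sizeof(uint64_t);  // blocks * slots
  constexpr std::size_t kCombineProfBase   = 196608;
  static_assert(kDispatchProfBase + kDispatchProfBytes == 132096,
                "dispatch prof_buf ends at byte 132096");
  static_assert(kCombineProfBase >= kDispatchProfBase + kDispatchProfBytes,
                "combine prof_buf must start past the dispatch prof_buf");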

Add MSCCLPP_EP_LL_PROFILE_PRINT_EVERY=N (default 1) so the readback path
(cudaStreamSynchronize / cudaMemcpy / cudaMemsetAsync) only runs every Nth
call. With EVERY=29 over a
30-iter bench (10 warmup + 20 timed) the readback fires once and the
PROFILE-ON BW penalty drops from ~10% to ~3% vs PROFILE-OFF, making
the instrumentation usable during real benchmarks.
Author: Qinghua Zhou
Date:   2026-05-08 18:38:08 +00:00
Parent: fec40601b8
Commit: 0e46d3052a
2 changed files with 82 additions and 31 deletions

@@ -1528,11 +1528,19 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT");
return e != nullptr && std::string(e) == "1";
}();
static const int kProfilePrintEvery = []() {
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
int v = (e != nullptr) ? std::atoi(e) : 1;
return v > 0 ? v : 1;
}();
if (kProfilePrintDisp) {
static int call_idx = 0;
int this_call = call_idx++;
if (this_call % kProfilePrintEvery == 0) {
cudaStreamSynchronize(launch_stream);
constexpr int kMaxBlocks = 1024;
std::vector<uint64_t> host_ts(kMaxBlocks * 4);
constexpr int kProfSlots = 16;
std::vector<uint64_t> host_ts(kMaxBlocks * kProfSlots);
cudaMemcpy(host_ts.data(),
static_cast<char*>(workspace) + 1024,
host_ts.size() * sizeof(uint64_t),
@@ -1550,8 +1558,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
double sum = 0;
int n = 0;
for (int b = 0; b < kMaxBlocks; ++b) {
uint64_t lo = host_ts[b * 4 + idx_lo];
uint64_t hi = host_ts[b * 4 + idx_hi];
uint64_t lo = host_ts[b * kProfSlots + idx_lo];
uint64_t hi = host_ts[b * kProfSlots + idx_hi];
if (lo == 0 || hi <= lo) continue;
uint64_t d = hi - lo;
if (d < mn) mn = d;
@@ -1565,23 +1573,33 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
to_us(mx, sm_clock_khz),
n);
};
// Slot map: 0=entry, 1=send-done, 2=kernel-exit, 4..6=wg unpack-start, 8..10=wg unpack-end.
auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
auto [unp_mn, unp_av, unp_mx, unp_n ] = stats(2, 3);
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 2);
// Per-wg wait (1 -> wg start) and unpack (wg start -> wg end).
constexpr int kMaxWg = 3;
double wait_av[kMaxWg]={0}, wait_mx[kMaxWg]={0}, unp_av[kMaxWg]={0}, unp_mx[kMaxWg]={0};
int wait_n[kMaxWg]={0}, unp_n[kMaxWg]={0};
for (int wg = 0; wg < kMaxWg; ++wg) {
auto [_a, av_w, mx_w, n_w] = stats(1, 4 + wg);
auto [_b, av_u, mx_u, n_u] = stats(4 + wg, 8 + wg);
wait_av[wg]=av_w; wait_mx[wg]=mx_w; wait_n[wg]=n_w;
unp_av[wg]=av_u; unp_mx[wg]=mx_u; unp_n[wg]=n_u;
}
fprintf(stderr,
"[ep-prof dispatch #%d r%d] blocks=%d "
"send=%.1f/%.1f/%.1fus "
"wait=%.1f/%.1f/%.1fus "
"unpack=%.1f/%.1f/%.1fus "
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
call_idx++, rank, send_n,
"wait_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
"unpack_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
"total=%.1f/%.1f/%.1fus\n",
this_call, rank, send_n,
send_mn, send_av, send_mx,
wait_mn, wait_av, wait_mx,
unp_mn, unp_av, unp_mx,
wait_av[0], wait_mx[0], wait_av[1], wait_mx[1], wait_av[2], wait_mx[2],
unp_av[0], unp_mx[0], unp_av[1], unp_mx[1], unp_av[2], unp_mx[2],
tot_mn, tot_av, tot_mx);
cudaMemsetAsync(static_cast<char*>(workspace) + 1024,
0, host_ts.size() * sizeof(uint64_t), launch_stream);
}
}
#endif
@@ -1678,13 +1696,20 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
}();
if (kProfilePrint) {
static int call_idx = 0;
static const int kProfilePrintEveryC = []() {
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
int v = (e != nullptr) ? std::atoi(e) : 1;
return v > 0 ? v : 1;
}();
int this_call = call_idx++;
if (this_call % kProfilePrintEveryC == 0) {
cudaStreamSynchronize(launch_stream);
// Worst-case block count: combine launcher uses kNumWarpGroupsRdma=3, so
// num_sms = ceil(num_experts / 3). Read 1024 entries to be safe.
constexpr int kMaxBlocks = 1024;
std::vector<uint64_t> host_ts(kMaxBlocks * 4);
cudaMemcpy(host_ts.data(),
static_cast<char*>(workspace) + 65536,
static_cast<char*>(workspace) + 196608,
host_ts.size() * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
static int sm_clock_khz = []() {
@@ -1726,14 +1751,15 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
"wait=%.1f/%.1f/%.1fus "
"reduce=%.1f/%.1f/%.1fus "
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
call_idx++, rank, send_n,
this_call, rank, send_n,
send_mn, send_av, send_mx,
wait_mn, wait_av, wait_mx,
redu_mn, redu_av, redu_mx,
tot_mn, tot_av, tot_mx);
// Zero out so next iter's "lo==0" filter rejects untouched slots.
cudaMemsetAsync(static_cast<char*>(workspace) + 65536,
cudaMemsetAsync(static_cast<char*>(workspace) + 196608,
0, host_ts.size() * sizeof(uint64_t), launch_stream);
}
}
#endif

@@ -184,11 +184,14 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
// Profile timestamps. Reuse workspace tail at byte offset 1024 (well past
// the atomic_counter / atomic_finish_counter arrays, which use
// num_experts * 4 * 2 = 512 bytes for typical num_experts <= 64).
// Layout: [block][4]uint64_t.
// [b*4 + 0] = send phase entry
// [b*4 + 1] = send phase done (after count write barrier)
// [b*4 + 2] = recv-count spinwait done
// [b*4 + 3] = kernel done
// Per-block layout: 16 uint64_t slots (1024 blocks * 16 * 8 = 128 KiB).
// [b*16 + 0] = send phase entry
// [b*16 + 1] = send phase done (after count write barrier)
// [b*16 + 2] = kernel done
// [b*16 + 3] = spare
// [b*16 + 4 + wg] = warp_group `wg` unpack-start (recv-count ack)
// [b*16 + 8 + wg] = warp_group `wg` unpack-end (after for-loop)
constexpr int kProfSlots = 16;
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
reinterpret_cast<char*>(atomic_counter_per_expert) + 1024);
#endif
@@ -197,7 +200,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;
#ifdef MSCCLPP_EP_LL_PROFILE
if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 0] = clock64();
#endif
// Expert counts
@@ -217,11 +220,12 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
auto dst_expert_idx =
warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast
// FP8 cast (or BF16 copy). Per-token staging into rdma_x[token_idx]. We
// intentionally do NOT bar.sync between tokens here -- a single barrier
// at the end of the per-token-loop covers all stages, halving the heavy
// 928-thread named barriers' contribution to the send phase.
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
auto int4_value = __ldg(x_int4 + i);
@@ -252,7 +256,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
}
}
asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
}
// Single block-level barrier across the staging warps (warp_id <
// num_warps - 1, i.e. 928 threads). Replaces the previous per-token
// bar.sync that fired num_tokens_per_block times.
asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
// Issue all sends for this block's tokens. Warps with `warp_id <
// num_topk` each map to one of the per-token top-k destinations and
// issue one IB write per token.
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto rdma_x_src_idx =
reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
auto dst_expert_idx =
warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
// Issue sends
if (dst_expert_idx >= 0) {
@@ -397,7 +414,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
#ifdef MSCCLPP_EP_LL_PROFILE
__syncthreads();
if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 1] = clock64();
#endif
// Receiving phase
@@ -438,9 +455,9 @@ LOW_LATENCY_DISPATCH_RECV:
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
#ifdef MSCCLPP_EP_LL_PROFILE
// Per-block: timestamp once across warp_group_id 0 sub_warp_id 1.
if (warp_group_id == 0 && sub_warp_id == 1 && lane_id == 0) {
prof_buf[sm_id * 4 + 2] = clock64();
// Per-warp_group unpack-start (recv-count ack-received) timestamp.
if (sub_warp_id == 1 && lane_id == 0) {
prof_buf[sm_id * kProfSlots + 4 + warp_group_id] = clock64();
}
#endif
@@ -464,10 +481,17 @@ LOW_LATENCY_DISPATCH_RECV:
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
}
}
#ifdef MSCCLPP_EP_LL_PROFILE
// Per-warp_group unpack-end timestamp (one writer per wg).
asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
if (sub_warp_id == 0 && lane_id == 0) {
prof_buf[sm_id * kProfSlots + 8 + warp_group_id] = clock64();
}
#endif
}
#ifdef MSCCLPP_EP_LL_PROFILE
__syncthreads();
if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 2] = clock64();
#endif
}
@@ -600,8 +624,9 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
// [b*4 + 1] = send phase done (after trailing flag write barrier)
// [b*4 + 2] = recv-flag spinwait done
// [b*4 + 3] = kernel done
// Offset 196608 is past dispatch's prof_buf (1024 + 1024*16*8 = 132096).
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
reinterpret_cast<char*>(atomic_clean_flag) + 65536);
reinterpret_cast<char*>(atomic_clean_flag) + 196608);
#endif
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
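
For completeness, a host-side sketch of how one (lo, hi) slot pair is reduced
to microseconds, mirroring the to_us()/stats() helpers in the readback hunks
above (assuming sm_clock_khz is the SM clock in kHz, as the existing
sm_clock_khz lambda suggests; function names here are illustrative):

  #include <cstdint>
  #include <vector>

  // clock64() cycle delta -> microseconds, given the SM clock in kHz.
  static double cycles_to_us(double cycles, int sm_clock_khz) {
    return cycles * 1000.0 / sm_clock_khz;
  }

  // Average delta between two slots across all blocks, skipping untouched
  // (zeroed) or wrapped entries -- the same lo==0 / hi<=lo filter as the diff.
  static double avg_slot_delta_us(const std::vector<uint64_t>& host_ts,
                                  int num_blocks, int prof_slots,
                                  int idx_lo, int idx_hi, int sm_clock_khz) {
    double sum = 0.0;
    int n = 0;
    for (int b = 0; b < num_blocks; ++b) {
      uint64_t lo = host_ts[b * prof_slots + idx_lo];
      uint64_t hi = host_ts[b * prof_slots + idx_hi];
      if (lo == 0 || hi <= lo) continue;
      sum += static_cast<double>(hi - lo);
      ++n;
    }
    return n > 0 ? cycles_to_us(sum / n, sm_clock_khz) : 0.0;
  }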