mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
ext/ep: per-warp_group dispatch profile slots + sampled readback
Extend the LL dispatch profile from 4 to 16 per-block uint64_t slots so we can observe per-warp_group wait and unpack times separately:

  [b*16 + 0]      send phase entry
  [b*16 + 1]      send phase done
  [b*16 + 2]      kernel done
  [b*16 + 4 + wg] warp_group wg unpack-start (recv-count ack received)
  [b*16 + 8 + wg] warp_group wg unpack-end (after for-loop)

This decomposes the previous lumped 'unpack' (kernel_exit minus last_wg_ack) into the actual per-warp_group copy time and the per-warp_group network wait, exposing that the copy itself is ~6 us / wg and the bulk of the time is network arrival jitter.

Move the combine prof_buf base from offset 65536 to 196608 so it sits past the dispatch prof_buf (1024 + 1024*16*8 = 132096), preserving the no-collision guarantee with dispatch's atomic counters.

Add MSCCLPP_EP_LL_PROFILE_PRINT_EVERY=N (default 1) so readback / cudaMemcpy / cudaMemset only run every Nth call. With EVERY=29 over a 30-iter bench (10 warmup + 20 timed) the readback fires once and the PROFILE-ON BW penalty drops from ~10% to ~3% vs PROFILE-OFF, making the instrumentation usable during real benchmarks.
This commit is contained in:
@@ -1528,11 +1528,19 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT");
|
||||
return e != nullptr && std::string(e) == "1";
|
||||
}();
|
||||
static const int kProfilePrintEvery = []() {
|
||||
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
|
||||
int v = (e != nullptr) ? std::atoi(e) : 1;
|
||||
return v > 0 ? v : 1;
|
||||
}();
|
||||
if (kProfilePrintDisp) {
|
||||
static int call_idx = 0;
|
||||
int this_call = call_idx++;
|
||||
if (this_call % kProfilePrintEvery == 0) {
|
||||
cudaStreamSynchronize(launch_stream);
|
||||
constexpr int kMaxBlocks = 1024;
|
||||
std::vector<uint64_t> host_ts(kMaxBlocks * 4);
|
||||
constexpr int kProfSlots = 16;
|
||||
std::vector<uint64_t> host_ts(kMaxBlocks * kProfSlots);
|
||||
cudaMemcpy(host_ts.data(),
|
||||
static_cast<char*>(workspace) + 1024,
|
||||
host_ts.size() * sizeof(uint64_t),
|
||||
@@ -1550,8 +1558,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
double sum = 0;
|
||||
int n = 0;
|
||||
for (int b = 0; b < kMaxBlocks; ++b) {
|
||||
uint64_t lo = host_ts[b * 4 + idx_lo];
|
||||
uint64_t hi = host_ts[b * 4 + idx_hi];
|
||||
uint64_t lo = host_ts[b * kProfSlots + idx_lo];
|
||||
uint64_t hi = host_ts[b * kProfSlots + idx_hi];
|
||||
if (lo == 0 || hi <= lo) continue;
|
||||
uint64_t d = hi - lo;
|
||||
if (d < mn) mn = d;
|
||||
@@ -1565,23 +1573,33 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
to_us(mx, sm_clock_khz),
|
||||
n);
|
||||
};
|
||||
// Slot map: 0=entry, 1=send-done, 2=kernel-exit, 4..6=wg unpack-start, 8..10=wg unpack-end.
|
||||
auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
|
||||
auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
|
||||
auto [unp_mn, unp_av, unp_mx, unp_n ] = stats(2, 3);
|
||||
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
|
||||
auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 2);
|
||||
// Per-wg wait (1 -> wg start) and unpack (wg start -> wg end).
|
||||
constexpr int kMaxWg = 3;
|
||||
double wait_av[kMaxWg]={0}, wait_mx[kMaxWg]={0}, unp_av[kMaxWg]={0}, unp_mx[kMaxWg]={0};
|
||||
int wait_n[kMaxWg]={0}, unp_n[kMaxWg]={0};
|
||||
for (int wg = 0; wg < kMaxWg; ++wg) {
|
||||
auto [_a, av_w, mx_w, n_w] = stats(1, 4 + wg);
|
||||
auto [_b, av_u, mx_u, n_u] = stats(4 + wg, 8 + wg);
|
||||
wait_av[wg]=av_w; wait_mx[wg]=mx_w; wait_n[wg]=n_w;
|
||||
unp_av[wg]=av_u; unp_mx[wg]=mx_u; unp_n[wg]=n_u;
|
||||
}
|
||||
fprintf(stderr,
|
||||
"[ep-prof dispatch #%d r%d] blocks=%d "
|
||||
"send=%.1f/%.1f/%.1fus "
|
||||
"wait=%.1f/%.1f/%.1fus "
|
||||
"unpack=%.1f/%.1f/%.1fus "
|
||||
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
|
||||
call_idx++, rank, send_n,
|
||||
"wait_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
|
||||
"unpack_wg(avg/max)=%.1f/%.1f, %.1f/%.1f, %.1f/%.1f us "
|
||||
"total=%.1f/%.1f/%.1fus\n",
|
||||
this_call, rank, send_n,
|
||||
send_mn, send_av, send_mx,
|
||||
wait_mn, wait_av, wait_mx,
|
||||
unp_mn, unp_av, unp_mx,
|
||||
wait_av[0], wait_mx[0], wait_av[1], wait_mx[1], wait_av[2], wait_mx[2],
|
||||
unp_av[0], unp_mx[0], unp_av[1], unp_mx[1], unp_av[2], unp_mx[2],
|
||||
tot_mn, tot_av, tot_mx);
|
||||
cudaMemsetAsync(static_cast<char*>(workspace) + 1024,
|
||||
0, host_ts.size() * sizeof(uint64_t), launch_stream);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1678,13 +1696,20 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
}();
|
||||
if (kProfilePrint) {
|
||||
static int call_idx = 0;
|
||||
static const int kProfilePrintEveryC = []() {
|
||||
const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT_EVERY");
|
||||
int v = (e != nullptr) ? std::atoi(e) : 1;
|
||||
return v > 0 ? v : 1;
|
||||
}();
|
||||
int this_call = call_idx++;
|
||||
if (this_call % kProfilePrintEveryC == 0) {
|
||||
cudaStreamSynchronize(launch_stream);
|
||||
// Worst-case block count: combine launcher uses kNumWarpGroupsRdma=3, so
|
||||
// num_sms = ceil(num_experts / 3). Read 1024 entries to be safe.
|
||||
constexpr int kMaxBlocks = 1024;
|
||||
std::vector<uint64_t> host_ts(kMaxBlocks * 4);
|
||||
cudaMemcpy(host_ts.data(),
|
||||
static_cast<char*>(workspace) + 65536,
|
||||
static_cast<char*>(workspace) + 196608,
|
||||
host_ts.size() * sizeof(uint64_t),
|
||||
cudaMemcpyDeviceToHost);
|
||||
static int sm_clock_khz = []() {
|
||||
@@ -1726,14 +1751,15 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
"wait=%.1f/%.1f/%.1fus "
|
||||
"reduce=%.1f/%.1f/%.1fus "
|
||||
"total=%.1f/%.1f/%.1fus (min/avg/max)\n",
|
||||
call_idx++, rank, send_n,
|
||||
this_call, rank, send_n,
|
||||
send_mn, send_av, send_mx,
|
||||
wait_mn, wait_av, wait_mx,
|
||||
redu_mn, redu_av, redu_mx,
|
||||
tot_mn, tot_av, tot_mx);
|
||||
// Zero out so next iter's "lo==0" filter rejects untouched slots.
|
||||
cudaMemsetAsync(static_cast<char*>(workspace) + 65536,
|
||||
cudaMemsetAsync(static_cast<char*>(workspace) + 196608,
|
||||
0, host_ts.size() * sizeof(uint64_t), launch_stream);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -184,11 +184,14 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
// Profile timestamps. Reuse workspace tail at byte offset 1024 (well past
|
||||
// the atomic_counter / atomic_finish_counter arrays, which use
|
||||
// num_experts * 4 * 2 = 512 bytes for typical num_experts <= 64).
|
||||
// Layout: [block][4]uint64_t.
|
||||
// [b*4 + 0] = send phase entry
|
||||
// [b*4 + 1] = send phase done (after count write barrier)
|
||||
// [b*4 + 2] = recv-count spinwait done
|
||||
// [b*4 + 3] = kernel done
|
||||
// Per-block layout: 16 uint64_t slots (1024 blocks * 16 * 8 = 128 KiB).
|
||||
// [b*16 + 0] = send phase entry
|
||||
// [b*16 + 1] = send phase done (after count write barrier)
|
||||
// [b*16 + 2] = kernel done
|
||||
// [b*16 + 3] = spare
|
||||
// [b*16 + 4 + wg] = warp_group `wg` unpack-start (recv-count ack)
|
||||
// [b*16 + 8 + wg] = warp_group `wg` unpack-end (after for-loop)
|
||||
constexpr int kProfSlots = 16;
|
||||
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<char*>(atomic_counter_per_expert) + 1024);
|
||||
#endif
|
||||
@@ -197,7 +200,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;
|
||||
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 0] = clock64();
|
||||
#endif
|
||||
|
||||
// Expert counts
|
||||
@@ -217,11 +220,12 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
|
||||
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
|
||||
|
||||
auto dst_expert_idx =
|
||||
warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
|
||||
|
||||
// FP8 cast
|
||||
// FP8 cast (or BF16 copy). Per-token staging into rdma_x[token_idx]. We
|
||||
// intentionally do NOT bar.sync between tokens here -- a single barrier
|
||||
// at the end of the per-token-loop covers all stages, halving the heavy
|
||||
// 928-thread named barriers' contribution to the send phase.
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
@@ -252,7 +256,20 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
|
||||
}
|
||||
}
|
||||
asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
|
||||
}
|
||||
// Single block-level barrier across the staging warps (warp_id <
|
||||
// num_warps - 1, i.e. 928 threads). Replaces the previous per-token
|
||||
// bar.sync that fired num_tokens_per_block times.
|
||||
asm volatile("bar.sync 1, %0;" ::"r"(num_threads));
|
||||
|
||||
// Issue all sends for this block's tokens. Warps with `warp_id <
|
||||
// num_topk` each map to one of the per-token top-k destinations and
|
||||
// issue one IB write per token.
|
||||
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
|
||||
const auto rdma_x_src_idx =
|
||||
reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
|
||||
auto dst_expert_idx =
|
||||
warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
|
||||
// Issue sends
|
||||
if (dst_expert_idx >= 0) {
|
||||
@@ -397,7 +414,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
|
||||
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
__syncthreads();
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 1] = clock64();
|
||||
#endif
|
||||
|
||||
// Receiving phase
|
||||
@@ -438,9 +455,9 @@ LOW_LATENCY_DISPATCH_RECV:
|
||||
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
|
||||
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
// Per-block: timestamp once across warp_group_id 0 sub_warp_id 1.
|
||||
if (warp_group_id == 0 && sub_warp_id == 1 && lane_id == 0) {
|
||||
prof_buf[sm_id * 4 + 2] = clock64();
|
||||
// Per-warp_group unpack-start (recv-count ack-received) timestamp.
|
||||
if (sub_warp_id == 1 && lane_id == 0) {
|
||||
prof_buf[sm_id * kProfSlots + 4 + warp_group_id] = clock64();
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -464,10 +481,17 @@ LOW_LATENCY_DISPATCH_RECV:
|
||||
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
|
||||
}
|
||||
}
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
// Per-warp_group unpack-end timestamp (one writer per wg).
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
|
||||
if (sub_warp_id == 0 && lane_id == 0) {
|
||||
prof_buf[sm_id * kProfSlots + 8 + warp_group_id] = clock64();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef MSCCLPP_EP_LL_PROFILE
|
||||
__syncthreads();
|
||||
if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
|
||||
if (thread_id == 0) prof_buf[sm_id * kProfSlots + 2] = clock64();
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -600,8 +624,9 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
|
||||
// [b*4 + 1] = send phase done (after trailing flag write barrier)
|
||||
// [b*4 + 2] = recv-flag spinwait done
|
||||
// [b*4 + 3] = kernel done
|
||||
// Offset 196608 is past dispatch's prof_buf (1024 + 1024*16*8 = 132096).
|
||||
uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<char*>(atomic_clean_flag) + 65536);
|
||||
reinterpret_cast<char*>(atomic_clean_flag) + 196608);
|
||||
#endif
|
||||
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
|
||||
|
||||
Reference in New Issue
Block a user