diff --git a/src/ext/ep/CMakeLists.txt b/src/ext/ep/CMakeLists.txt
index f1394a0e..237cf509 100644
--- a/src/ext/ep/CMakeLists.txt
+++ b/src/ext/ep/CMakeLists.txt
@@ -59,6 +59,13 @@ endif()
 Python_add_library(mscclpp_ep_cpp MODULE ${EP_SOURCES})
 target_compile_definitions(mscclpp_ep_cpp PRIVATE TORCH_EXTENSION_NAME=mscclpp_ep_cpp)

+# Optional: enable in-kernel timestamp profiling for the low-latency dispatch/combine kernels.
+# `cmake -DMSCCLPP_EP_LL_PROFILE=ON ...`. Runtime print is gated by
+# the env var MSCCLPP_EP_LL_PROFILE_PRINT=1.
+option(MSCCLPP_EP_LL_PROFILE "Enable in-kernel timestamp profiling for LL dispatch/combine" OFF)
+if(MSCCLPP_EP_LL_PROFILE)
+  target_compile_definitions(mscclpp_ep_cpp PRIVATE MSCCLPP_EP_LL_PROFILE)
+endif()
 # Inherit ibverbs / mlx5dv defs from the core library so the IBGDA host-side
 # plumbing (see ibgda_setup.cc, gated by MSCCLPP_EP_HAVE_IBGDA) compiles in.
 if(IBVERBS_FOUND)
diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index b13c8762..b17691a3 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -1522,6 +1522,69 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
   };
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

+#ifdef MSCCLPP_EP_LL_PROFILE
+  // Read back per-block timestamps written by dispatch (offset 1024 in workspace).
+  static const bool kProfilePrintDisp = []() {
+    const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT");
+    return e != nullptr && std::string(e) == "1";
+  }();
+  if (kProfilePrintDisp) {
+    static int call_idx = 0;
+    cudaStreamSynchronize(launch_stream);
+    constexpr int kMaxBlocks = 1024;
+    std::vector<uint64_t> host_ts(kMaxBlocks * 4);
+    cudaMemcpy(host_ts.data(),
+               static_cast<uint8_t*>(workspace) + 1024,
+               host_ts.size() * sizeof(uint64_t),
+               cudaMemcpyDeviceToHost);
+    static int sm_clock_khz = []() {
+      int dev = 0; cudaGetDevice(&dev);
+      int khz = 0; cudaDeviceGetAttribute(&khz, cudaDevAttrClockRate, dev);
+      return khz;
+    }();
+    auto to_us = [](uint64_t ticks, int khz) {
+      return static_cast<double>(ticks) * 1000.0 / static_cast<double>(khz);
+    };
+    auto stats = [&](int idx_lo, int idx_hi) {
+      uint64_t mn = ~0ull, mx = 0ull;
+      double sum = 0;
+      int n = 0;
+      for (int b = 0; b < kMaxBlocks; ++b) {
+        uint64_t lo = host_ts[b * 4 + idx_lo];
+        uint64_t hi = host_ts[b * 4 + idx_hi];
+        if (lo == 0 || hi <= lo) continue;
+        uint64_t d = hi - lo;
+        if (d < mn) mn = d;
+        if (d > mx) mx = d;
+        sum += d;
+        ++n;
+      }
+      if (n == 0) return std::make_tuple(0.0, 0.0, 0.0, 0);
+      return std::make_tuple(to_us(mn, sm_clock_khz),
+                             to_us(static_cast<uint64_t>(sum / n), sm_clock_khz),
+                             to_us(mx, sm_clock_khz),
+                             n);
+    };
+    auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
+    auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
+    auto [unp_mn, unp_av, unp_mx, unp_n ] = stats(2, 3);
+    auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
+    fprintf(stderr,
+            "[ep-prof dispatch #%d r%d] blocks=%d "
+            "send=%.1f/%.1f/%.1fus "
+            "wait=%.1f/%.1f/%.1fus "
+            "unpack=%.1f/%.1f/%.1fus "
+            "total=%.1f/%.1f/%.1fus (min/avg/max)\n",
+            call_idx++, rank, send_n,
+            send_mn, send_av, send_mx,
+            wait_mn, wait_av, wait_mx,
+            unp_mn, unp_av, unp_mx,
+            tot_mn, tot_av, tot_mx);
+    cudaMemsetAsync(static_cast<uint8_t*>(workspace) + 1024,
+                    0, host_ts.size() * sizeof(uint64_t), launch_stream);
+  }
+#endif
+
   std::optional<EventHandle> event;
   if (async) {
     event = EventHandle(launch_stream);
@@ -1605,6 +1668,75 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functi
   };
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

+#ifdef MSCCLPP_EP_LL_PROFILE
+  // Read back per-block timestamps written by the combine kernel (offset 65536
+  // in workspace). The offset sits past dispatch's atomic counters and past
+  // dispatch's own profiling buffer so the two kernels never clobber each
+  // other's scratch; see the layout comment on the combine kernel in
+  // internode_ll.cu.
+  static const bool kProfilePrintComb = []() {
+    const char* e = std::getenv("MSCCLPP_EP_LL_PROFILE_PRINT");
+    return e != nullptr && std::string(e) == "1";
+  }();
+  if (kProfilePrintComb) {
+    static int call_idx = 0;
+    cudaStreamSynchronize(launch_stream);
+    constexpr int kMaxBlocks = 1024;
+    std::vector<uint64_t> host_ts(kMaxBlocks * 4);
+    cudaMemcpy(host_ts.data(),
+               static_cast<uint8_t*>(workspace) + 65536,
+               host_ts.size() * sizeof(uint64_t),
+               cudaMemcpyDeviceToHost);
+    static int sm_clock_khz = []() {
+      int dev = 0; cudaGetDevice(&dev);
+      int khz = 0; cudaDeviceGetAttribute(&khz, cudaDevAttrClockRate, dev);
+      return khz;
+    }();
+    auto to_us = [](uint64_t ticks, int khz) {
+      // ticks / (khz/1000 ticks-per-us) = us, i.e. ticks * 1000 / khz.
+      return static_cast<double>(ticks) * 1000.0 / static_cast<double>(khz);
+    };
+    auto stats = [&](int idx_lo, int idx_hi) {
+      uint64_t mn = ~0ull, mx = 0ull;
+      double sum = 0;
+      int n = 0;
+      for (int b = 0; b < kMaxBlocks; ++b) {
+        uint64_t lo = host_ts[b * 4 + idx_lo];
+        uint64_t hi = host_ts[b * 4 + idx_hi];
+        if (lo == 0 || hi <= lo) continue;
+        uint64_t d = hi - lo;
+        if (d < mn) mn = d;
+        if (d > mx) mx = d;
+        sum += d;
+        ++n;
+      }
+      if (n == 0) return std::make_tuple(0.0, 0.0, 0.0, 0);
+      return std::make_tuple(to_us(mn, sm_clock_khz),
+                             to_us(static_cast<uint64_t>(sum / n), sm_clock_khz),
+                             to_us(mx, sm_clock_khz),
+                             n);
+    };
+    auto [send_mn, send_av, send_mx, send_n] = stats(0, 1);
+    auto [wait_mn, wait_av, wait_mx, wait_n] = stats(1, 2);
+    auto [redu_mn, redu_av, redu_mx, redu_n] = stats(2, 3);
+    auto [tot_mn, tot_av, tot_mx, tot_n ] = stats(0, 3);
+    fprintf(stderr,
+            "[ep-prof combine #%d r%d] blocks=%d "
+            "send=%.1f/%.1f/%.1fus "
+            "wait=%.1f/%.1f/%.1fus "
+            "reduce=%.1f/%.1f/%.1fus "
+            "total=%.1f/%.1f/%.1fus (min/avg/max)\n",
+            call_idx++, rank, send_n,
+            send_mn, send_av, send_mx,
+            wait_mn, wait_av, wait_mx,
+            redu_mn, redu_av, redu_mx,
+            tot_mn, tot_av, tot_mx);
+    // Zero out so next iter's "lo==0" filter rejects untouched slots.
+    cudaMemsetAsync(static_cast<uint8_t*>(workspace) + 65536,
+                    0, host_ts.size() * sizeof(uint64_t), launch_stream);
+  }
+#endif
+
   std::optional<EventHandle> event;
   if (async) {
     event = EventHandle(launch_stream);
diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu
index 49980ed0..cbe12cd4 100644
--- a/src/ext/ep/kernels/internode_ll.cu
+++ b/src/ext/ep/kernels/internode_ll.cu
@@ -180,9 +180,26 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
   const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
   EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

+#ifdef MSCCLPP_EP_LL_PROFILE
+  // Profile timestamps. Reuse workspace tail at byte offset 1024 (well past
+  // the atomic_counter / atomic_finish_counter arrays, which use
+  // num_experts * 4 * 2 = 512 bytes for typical num_experts <= 64).
+  // Layout: [block][4]uint64_t.
+  //   [b*4 + 0] = send phase entry
+  //   [b*4 + 1] = send phase done (after count write barrier)
+  //   [b*4 + 2] = recv-count spinwait done
+  //   [b*4 + 3] = kernel done
+  uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
+      reinterpret_cast<uint8_t*>(atomic_counter_per_expert) + 1024);
+#endif
+
   // Sending phase
   if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;

+#ifdef MSCCLPP_EP_LL_PROFILE
+  if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
+#endif
+
   // Expert counts
   __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];

@@ -378,6 +395,11 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dis
     }
     __syncwarp();

+#ifdef MSCCLPP_EP_LL_PROFILE
+  __syncthreads();
+  if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
+#endif
+
   // Receiving phase
 LOW_LATENCY_DISPATCH_RECV:
   if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return;

@@ -415,6 +437,13 @@ LOW_LATENCY_DISPATCH_RECV:
     num_recv_tokens = shared_num_recv_tokens[warp_group_id];
     recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

+#ifdef MSCCLPP_EP_LL_PROFILE
+    // Per-block: timestamp once across warp_group_id 0 sub_warp_id 1.
+    if (warp_group_id == 0 && sub_warp_id == 1 && lane_id == 0) {
+      prof_buf[sm_id * 4 + 2] = clock64();
+    }
+#endif
+
     EP_DEVICE_ASSERT(num_scales <= 64);
     for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
       const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
@@ -436,6 +465,10 @@ LOW_LATENCY_DISPATCH_RECV:
       }
     }
   }
+#ifdef MSCCLPP_EP_LL_PROFILE
+  __syncthreads();
+  if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
+#endif
 }

 void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info,
@@ -553,8 +586,30 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
   constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
   EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");

+#ifdef MSCCLPP_EP_LL_PROFILE
+  // Profile timestamps. The same workspace is shared with the dispatch
+  // kernel, which uses bytes [0..2*num_experts*sizeof(int)] for its
+  // atomic_counter_per_expert / atomic_finish_counter_per_expert arrays
+  // (up to ~512 B for 64 experts). Dispatch's own prof_buf lives at
+  // byte offset 1024..(1024 + kMaxBlocks*32) = 1024..33792. Place the
+  // combine prof_buf well past that to avoid corrupting dispatch's
+  // atomic counters between iterations (which previously caused iter 1+
+  // to hang in dispatch's FINISHED_SUM_TAG spinwait).
+  // Per-block [4]uint64_t scratch. Slots:
+  //   [b*4 + 0] = send phase entry
+  //   [b*4 + 1] = send phase done (after trailing flag write barrier)
+  //   [b*4 + 2] = recv-flag spinwait done
+  //   [b*4 + 3] = kernel done
+  uint64_t* prof_buf = reinterpret_cast<uint64_t*>(
+      reinterpret_cast<uint8_t*>(atomic_clean_flag) + 65536);
+#endif
+
   if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;

+#ifdef MSCCLPP_EP_LL_PROFILE
+  if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64();
+#endif
+
   if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
 #pragma unroll
     for (int i = lane_id; i < num_next_clean_int; i += 32) next_clean[i] = 0;
@@ -666,6 +721,11 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com
     __syncwarp();
   }

+#ifdef MSCCLPP_EP_LL_PROFILE
+  __syncthreads();
+  if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64();
+#endif
+
 LOW_LATENCY_COMBINE_RECV:
   if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return;

@@ -675,6 +735,10 @@ LOW_LATENCY_COMBINE_RECV:
     while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0)
       ;
   }
+#ifdef MSCCLPP_EP_LL_PROFILE
+  __syncthreads();
+  if (thread_id == 0) prof_buf[sm_id * 4 + 2] = clock64();
+#endif
   cg::this_grid().sync();

   EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
@@ -712,6 +776,10 @@ LOW_LATENCY_COMBINE_RECV:
       (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
     }
   }
+#ifdef MSCCLPP_EP_LL_PROFILE
+  __syncthreads();
+  if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64();
+#endif
 }

 void combine(void* combined_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, const void* x,