ext/ep: lift cached_notify nc cap, strided-warp head fixup (HT 442/399 GB/s)
Two latent bugs in internode cached_notify blocked num_channels scaling:

1) Launch thread count exceeded the per-block limit. num_threads = max(128, 32 * num_channels) produced >1024 at nc>=33, above GB200's cudaDevAttrMaxThreadsPerBlock = 1024. cudaLaunchKernelEx returned cudaErrorInvalidValue, which silently corrupted the buffer-clean path at nc<=64 and hard-failed at nc>=68. Fix: cap the launch at 1024 threads and decouple launch sizing from work distribution.

2) Hardcoded warp_id < num_channels in the sm_id>=2 branches, so the head-fixup work required num_warps >= num_channels. Fix: replace it with a strided per-warp loop (for ch = warp_id; ch < num_channels; ch += num_warps) so any thread count covers any channel count.

With both fixes, num_channels scales from 20 to 152 (= SM count, the cooperative-grid ceiling at 2x SMs). HT throughput grows from 78/94 GB/s (dispatch/combine) at nc=20 to 442/399 GB/s at nc=152, a 5.7x / 4.2x speedup. The gap to LL closes from 9.0x / 7.6x down to 1.6x / 1.8x. No effect on correctness at nc<=32 (that path always had num_warps >= num_channels); the changes are purely additive at low nc and remove the silent-failure ceiling at high nc.
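For illustration, a minimal, self-contained CUDA sketch of the two patterns the fix relies on: capping the per-block launch size at the device limit, and letting each warp stride over channels so any thread count covers any channel count. The kernel and names here (demo_notify, per_channel_out, wanted_threads) are placeholders invented for this sketch, not the actual cached_notify code or its real head-fixup work.

// Sketch only: illustrative names, not the real kernel.
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_notify(int* per_channel_out, int num_channels) {
    const int num_warps = static_cast<int>(blockDim.x) / 32;
    const int warp_id = static_cast<int>(threadIdx.x) / 32;
    const int lane_id = static_cast<int>(threadIdx.x) % 32;

    // Strided per-warp loop: correct whether num_warps >= num_channels or not,
    // since every channel index is visited by exactly one warp iteration.
    for (int ch = warp_id; ch < num_channels; ch += num_warps)
        if (lane_id == 0)
            per_channel_out[ch] = ch;  // stand-in for the real per-channel work
}

int main() {
    const int num_channels = 152;                                  // e.g. one channel per SM
    const int wanted_threads = std::max(128, 32 * num_channels);   // old sizing rule: >1024 here

    // Cap the block size at the device limit; without the cap, a >1024-thread
    // launch fails with cudaErrorInvalidValue and the kernel never runs.
    int max_threads = 0;
    cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, 0);
    const int launch_threads = std::min(wanted_threads, max_threads);

    int* per_channel_out = nullptr;
    cudaMalloc(&per_channel_out, num_channels * sizeof(int));
    demo_notify<<<1, launch_threads>>>(per_channel_out, num_channels);
    std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    cudaFree(per_channel_out);
    return 0;
}

The point of the stride is that work distribution no longer depends on launch size: whether the block runs 4 warps or 32, every channel index is covered exactly once, which is what lets the host cap the launch without changing kernel results.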
@@ -1550,6 +1550,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
     auto sm_id = static_cast<int>(blockIdx.x);
     auto thread_id = static_cast<int>(threadIdx.x);
     auto num_threads = static_cast<int>(blockDim.x);
+    auto num_warps = num_threads / 32;
     auto warp_id = thread_id / 32;
     auto lane_id = get_lane_id();
 
@@ -1639,49 +1640,52 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
     } else if (sm_id == 2) {
         if (is_cached_dispatch) return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(num_rdma_ranks <= 32);
 
-        // Iterate in reverse order
-        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
-            int token_start_idx, token_end_idx;
-            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);
+        // Iterate in reverse order. Stride over channels in warp granularity so we
+        // support num_channels > num_warps (per-block thread cap on GB200 is 1024).
+        if (lane_id < num_rdma_ranks) {
+            for (int ch = warp_id; ch < num_channels; ch += num_warps) {
+                int token_start_idx, token_end_idx;
+                get_channel_task_range(num_combined_tokens, num_channels, ch, token_start_idx, token_end_idx);
 
-            // NOTES: `1 << 25` is a heuristic large number
-            int last_head = 1 << 25;
-            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
-                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
-                if (current_head < 0) {
-                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
-                } else {
-                    last_head = current_head;
-                }
-            }
-        }
+                // NOTES: `1 << 25` is a heuristic large number
+                int last_head = 1 << 25;
+                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
+                    auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
+                    if (current_head < 0) {
+                        combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
+                    } else {
+                        last_head = current_head;
+                    }
+                }
+            }
+        }
     } else {
         if (is_cached_dispatch) return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");
 
-        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
+        // Stride over channels per warp so num_channels > num_warps works.
+        if (lane_id < NUM_MAX_NVL_PEERS) {
             for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) {
-                // Iterate in reverse order
-                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
-                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
-                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
-                token_start_idx += shift, token_end_idx += shift;
+                for (int ch = warp_id; ch < num_channels; ch += num_warps) {
+                    // Iterate in reverse order
+                    int token_start_idx = ch == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + ch - 1];
+                    int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + ch];
+                    int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
+                    token_start_idx += shift, token_end_idx += shift;
 
-                // NOTES: `1 << 25` is a heuristic large number
-                int last_head = 1 << 25;
-                #pragma unroll
-                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
-                    auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
-                    if (current_head < 0) {
-                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
-                    } else {
-                        last_head = current_head;
-                    }
-                }
+                    // NOTES: `1 << 25` is a heuristic large number
+                    int last_head = 1 << 25;
+                    for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
+                        auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
+                        if (current_head < 0) {
+                            combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
+                        } else {
+                            last_head = current_head;
+                        }
+                    }
+                }
             }
         }
@@ -1712,9 +1716,16 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
     EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
     EP_HOST_ASSERT(num_channels * 2 > 3);
 
-    // Launch kernel
+    // Launch kernel.
+    // NOTE: cached_notify's sm_id>=2 work iterates `warp_id < num_channels`, so the
+    // kernel originally requested 32*num_channels threads per block. On GB200 the
+    // per-block thread limit is 1024 (= 32 warps), so nc>32 silently failed launch
+    // (cudaLaunchKernelEx returns cudaErrorInvalidValue / "invalid argument").
+    // Cap at 1024 here and add a strided fallback in the kernel (warps loop over
+    // channels in stride-num_warps) so nc up to the SM-count is supported.
     auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
-    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
+    const int launch_threads = std::min(num_threads, 1024);
+    SETUP_LAUNCH_CONFIG(num_channels * 2, launch_threads, stream);
     LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first,
                   nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels,
                   rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs,