diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu
index b64205e0..5c25b0fc 100644
--- a/src/ext/ep/kernels/internode.cu
+++ b/src/ext/ep/kernels/internode.cu
@@ -1550,6 +1550,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
     auto sm_id = static_cast<int>(blockIdx.x);
     auto thread_id = static_cast<int>(threadIdx.x);
     auto num_threads = static_cast<int>(blockDim.x);
+    auto num_warps = num_threads / 32;
     auto warp_id = thread_id / 32;
     auto lane_id = get_lane_id();
 
@@ -1639,49 +1640,52 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
     } else if (sm_id == 2) {
         if (is_cached_dispatch) return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(num_rdma_ranks <= 32);
 
-        // Iterate in reverse order
-        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
-            int token_start_idx, token_end_idx;
-            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);
+        // Iterate in reverse order. Stride over channels in warp granularity so we
+        // support num_channels > num_warps (per-block thread cap on GB200 is 1024).
+        if (lane_id < num_rdma_ranks) {
+            for (int ch = warp_id; ch < num_channels; ch += num_warps) {
+                int token_start_idx, token_end_idx;
+                get_channel_task_range(num_combined_tokens, num_channels, ch, token_start_idx, token_end_idx);
 
-            // NOTES: `1 << 25` is a heuristic large number
-            int last_head = 1 << 25;
-            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
-                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
-                if (current_head < 0) {
-                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
-                } else {
-                    last_head = current_head;
+                // NOTES: `1 << 25` is a heuristic large number
+                int last_head = 1 << 25;
+                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
+                    auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
+                    if (current_head < 0) {
+                        combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
+                    } else {
+                        last_head = current_head;
+                    }
                 }
             }
         }
     } else {
         if (is_cached_dispatch) return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");
 
-        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
+        // Stride over channels per warp so num_channels > num_warps works.
+        if (lane_id < NUM_MAX_NVL_PEERS) {
             for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) {
-                // Iterate in reverse order
-                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
-                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
-                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
-                token_start_idx += shift, token_end_idx += shift;
+                for (int ch = warp_id; ch < num_channels; ch += num_warps) {
+                    // Iterate in reverse order
+                    int token_start_idx = ch == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + ch - 1];
+                    int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + ch];
+                    int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
+                    token_start_idx += shift, token_end_idx += shift;
 
-                // NOTES: `1 << 25` is a heuristic large number
-                int last_head = 1 << 25;
-#pragma unroll
-                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
-                    auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
-                    if (current_head < 0) {
-                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
-                    } else {
-                        last_head = current_head;
+                    // NOTES: `1 << 25` is a heuristic large number
+                    int last_head = 1 << 25;
+                    for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
+                        auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
+                        if (current_head < 0) {
+                            combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
+                        } else {
+                            last_head = current_head;
+                        }
                     }
                 }
             }
@@ -1712,9 +1716,16 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
     EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
     EP_HOST_ASSERT(num_channels * 2 > 3);
 
-    // Launch kernel
+    // Launch kernel.
+    // NOTE: cached_notify's sm_id >= 2 work iterates `warp_id < num_channels`, so the
+    // kernel originally requested 32 * num_channels threads per block. On GB200 the
+    // per-block thread limit is 1024 (= 32 warps), so nc > 32 silently failed to launch
+    // (cudaLaunchKernelEx returns cudaErrorInvalidValue / "invalid argument").
+    // Cap at 1024 here and add a strided fallback in the kernel (warps loop over
+    // channels in stride-num_warps) so nc up to the SM count is supported.
     auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
-    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
+    const int launch_threads = std::min(num_threads, 1024);
+    SETUP_LAUNCH_CONFIG(num_channels * 2, launch_threads, stream);
     LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
                   nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels,
                   rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs,
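
For reviewers who want to exercise the new control flow in isolation, here is a minimal standalone sketch of the stride-num_warps pattern the kernel now uses. Everything in it is a hypothetical stand-in, not code from this patch: `toy_mark_tail`, `NUM_RANKS`, and the simplified even-split task range take the place of `cached_notify`, `num_rdma_ranks`, and `get_channel_task_range`.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int NUM_RANKS = 8;  // stand-in for num_rdma_ranks (must be <= 32)

// Each warp owns channels warp_id, warp_id + num_warps, warp_id + 2 * num_warps, ...
// so correctness no longer depends on having at least num_channels warps per block.
__global__ void toy_mark_tail(int* head, int num_tokens, int num_channels) {
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    if (lane_id >= NUM_RANKS)
        return;

    for (int ch = warp_id; ch < num_channels; ch += num_warps) {
        // Simplified task range: split tokens evenly across channels.
        int start = ch * num_tokens / num_channels;
        int end = (ch + 1) * num_tokens / num_channels;

        // Reverse scan, as in the patch: rewrite each negative hole as
        // -(next valid head to the right) - 1.
        int last_head = 1 << 25;  // heuristic large number
        for (int t = end - 1; t >= start; --t) {
            int cur = head[t * NUM_RANKS + lane_id];
            if (cur < 0)
                head[t * NUM_RANKS + lane_id] = -last_head - 1;
            else
                last_head = cur;
        }
    }
}

int main() {
    const int num_tokens = 256, num_channels = 48;  // more channels than warps
    std::vector<int> h(num_tokens * NUM_RANKS);
    for (size_t i = 0; i < h.size(); ++i)
        h[i] = (i % 3 == 0) ? -1 : static_cast<int>(i);  // sprinkle holes

    int* d = nullptr;
    cudaMalloc(&d, h.size() * sizeof(int));
    cudaMemcpy(d, h.data(), h.size() * sizeof(int), cudaMemcpyHostToDevice);

    toy_mark_tail<<<1, 1024>>>(d, num_tokens, num_channels);  // 32 warps < 48 channels
    cudaMemcpy(h.data(), d, h.size() * sizeof(int), cudaMemcpyDeviceToHost);
    printf("head[0] after scan: %d\n", h[0]);  // a hole, now encodes the next head

    cudaFree(d);
    return 0;
}
```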
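The host-side cap can be sanity-checked the same way. The sketch below (hypothetical `noop` kernel, not from this repo) launches with 32 * nc threads for an nc > 32 to reproduce the rejected launch, then retries with the per-block limit queried via `cudaDeviceGetAttribute` instead of a hardcoded 1024; the exact error code/string varies by launch path, so the comments only claim a configuration error. Querying the attribute keeps the cap correct on any device rather than baking in the current architectural limit.

```cuda
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop() {}

int main() {
    // nc > 32 reproduces the failure mode the patch describes: 32 * nc exceeds
    // the 1024-thread-per-block limit, so the launch is rejected before running.
    int nc = 48;
    noop<<<nc * 2, 32 * nc>>>();
    printf("uncapped: %s\n", cudaGetErrorString(cudaGetLastError()));
    // Expect a launch-configuration error above (the exact code/string depends
    // on the launch path, e.g. "invalid configuration argument").

    // Portable fix: cap against the device's actual per-block limit rather
    // than hardcoding 1024.
    int device = 0, max_threads = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, device);
    noop<<<nc * 2, std::min(32 * nc, max_threads)>>>();
    cudaDeviceSynchronize();
    printf("capped:   %s\n", cudaGetErrorString(cudaGetLastError()));  // "no error"
    return 0;
}
```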