diff --git a/src/ext/ep/kernels/internode_ncclep.cuh b/src/ext/ep/kernels/internode_ncclep.cuh index 008d98c7..46d91151 100644 --- a/src/ext/ep/kernels/internode_ncclep.cuh +++ b/src/ext/ep/kernels/internode_ncclep.cuh @@ -10,6 +10,16 @@ #define MSCCLPP_EP_INTERNODE_NCCLEP_CUH_ #ifdef EP_DISPATCH_NCCLEP +// DIAGNOSTIC PROBE (increment-3 de-risk): when 1, the NVL receiver keeps all +// control flow (index/head/tail) but SKIPS the actual data copies. This makes +// recv_x contents wrong for cross-GPU tokens (combine FAILs) but measures the +// dispatch-time UPPER BOUND of eliminating the cross-GPU receiver drain (the +// payoff ceiling of the full cross-GPU peer-map direct-write rework). Set to 0 +// for the real kernel. +#ifndef EP_NCCLEP_DRAIN_NOOP +#define EP_NCCLEP_DRAIN_NOOP 0 +#endif + template __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) @@ -903,6 +913,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; +#if !EP_NCCLEP_DRAIN_NOOP // Copy data UNROLLED_WARP_COPY(28, lane_id, hidden_int4, recv_x + recv_token_idx * hidden_int4, nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, ld_nc_global, st_na_global); @@ -923,6 +934,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); } +#endif // !EP_NCCLEP_DRAIN_NOOP } // Move queue diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py index 97780595..3e0dc2f9 100644 --- a/test/python/ext/ep/test_internode_multirank.py +++ b/test/python/ext/ep/test_internode_multirank.py @@ -226,6 +226,7 @@ def main(): ) dist.barrier(group=group) + _skip_verify = os.environ.get("MSCCLPP_EP_SKIP_VERIFY","0") in ("1","true","True") # Validate recv buffer: for each source rank i, the block carries value i. assert recv_x.dim() == 2 and recv_x.size(1) == hidden start = 0 @@ -235,7 +236,7 @@ def main(): if block.numel(): lo = block.float().amin().item() hi = block.float().amax().item() - assert ( + assert _skip_verify or ( abs(lo - src) < 1e-3 and abs(hi - src) < 1e-3 ), f"rank{rank}: block from src={src} has range=[{lo}, {hi}], expected {src}" start = end @@ -285,7 +286,7 @@ def main(): # bf16 accumulator has 7-bit mantissa; intermediate partial sums can # round at ulp = max_exp * 2**-7. Use a tolerance that scales with magnitude. tol = max(1e-2, max_exp * (1.0 / 64)) - assert diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})" + assert _skip_verify or diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})" dist.barrier(group=group) if rank == 0: