diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index c7a3e163..74aa7922 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -1228,12 +1228,16 @@ Buffer::internode_dispatch(
     recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
 
     // Just a barrier and clean flags
+    if (nvls_ht_enabled) ++nvls_ht_cached_epoch;
     internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels,
                              0, nullptr, nullptr, nullptr, nullptr, rdma_buffer_ptr,
                              config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
                              config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank,
                              comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
                              num_nvl_bytes, true, low_latency_mode, port_channel_handles_device_ptr.get(),
-                             memory_channel_handles_device_ptr.get());
+                             memory_channel_handles_device_ptr.get(),
+                             nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
+                             nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
+                             nvls_ht_off_barrier, nvls_ht_cached_epoch);
     move_fifo_slots(2);
   } else {
     rdma_channel_prefix_matrix =
@@ -1459,7 +1463,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, ...> Buffer::internode_combine(
                              combined_nvl_head.data_ptr<int>(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
                              config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream,
                              config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode,
-                             port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
+                             port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
+                             nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
+                             nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
+                             nvls_ht_off_barrier, (nvls_ht_enabled ? ++nvls_ht_cached_epoch : 0));
   move_fifo_slots(2);
 
   auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp
index 06f8f912..36453af5 100644
--- a/src/ext/ep/buffer.hpp
+++ b/src/ext/ep/buffer.hpp
@@ -147,6 +147,11 @@ struct Buffer {
   // before each kernel launch that uses an NVLS barrier; the kernel spins
   // until the barrier slot reaches `epoch * num_ranks`.
   uint64_t nvls_ht_epoch = 0;
+  // Independent epoch for cached_notify barrier slots (offsets +24 / +32),
+  // since those slots are only touched when the cached path is taken — using
+  // the shared `nvls_ht_epoch` would over-count the expected value relative
+  // to the number of times those particular slots have actually been bumped.
+  uint64_t nvls_ht_cached_epoch = 0;
   // Worst-case shape parameters used to size the buffer:
   //   stride_per_channel = num_rdma_ranks * num_rdma_ranks (counter slots)
   //   We allocate for `kNvlsMaxChannels` so any `num_sms` config fits.
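Why `nvls_ht_cached_epoch` must be tracked separately, reduced to a minimal host-side sketch (not part of the patch; the struct and function names are hypothetical stand-ins for the two Buffer fields): a barrier slot's value only advances when a launch actually bumps it, so the spin target has to count bumps of that slot rather than launches overall.

    #include <cstdint>

    // Models the invariant behind the two epoch counters: for any barrier
    // slot, expected == (launches that bumped this slot) * num_ranks.
    struct EpochBook {
      uint64_t dispatch_epoch = 0;  // stands in for nvls_ht_epoch
      uint64_t cached_epoch = 0;    // stands in for nvls_ht_cached_epoch
    };

    // Only the cached path bumps the +24/+32 slots, so its spin target is
    // cached_epoch * num_ranks. Reusing dispatch_epoch here would also count
    // launches that skipped these slots, so the expected value would run
    // ahead of the slot and the kernel would spin forever.
    uint64_t cached_barrier_target(EpochBook& book, int num_ranks) {
      ++book.cached_epoch;
      return book.cached_epoch * static_cast<uint64_t>(num_ranks);
    }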
diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh
index cd0d1764..d728faa9 100644
--- a/src/ext/ep/kernels/api.cuh
+++ b/src/ext/ep/kernels/api.cuh
@@ -104,7 +104,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank,
                    cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                    bool is_cached_dispatch, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles,
-                   mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
+                   mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                   void* nvls_mc_ptr = nullptr, void* nvls_dev_ptr = nullptr,
+                   size_t nvls_off_barrier = 0, uint64_t nvls_epoch = 0);
 
 void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights,
              const bool* is_combined_token_in_rank, const void* x, const float* topk_weights,
              const int* combined_rdma_head, const int* combined_nvl_head,
diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu
index 2b6b42c5..70124e38 100644
--- a/src/ext/ep/kernels/internode.cu
+++ b/src/ext/ep/kernels/internode.cu
@@ -880,21 +880,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
     if (is_token_in_rank_uint64 != 0) {
       rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
       while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
-        // Phase 4: prefer the legacy local-load when peer_rdma_bases
-        // is in use — cross-node head feedback now writes the
-        // absolute value directly into our local rdma_channel_head
-        // slot via fabric VA, so a `ld_volatile_global` here sees it.
-        if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
-          cached_rdma_channel_head =
-              static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
-        } else if (nvls_head_dev != nullptr) {
-          // Phase 3: NVLS fast path (self loop only when fabric VA on).
-          cached_rdma_channel_head =
-              static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, lane_id));
-        } else {
-          cached_rdma_channel_head =
-              static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
-        }
+        // Phase 4: head feedback path — cross-node uses a fabric-VA store,
+        // self-loop uses a local atomic. Both end up in rdma_channel_head;
+        // a single read via ld_volatile_global covers them. NVLS removed.
+        cached_rdma_channel_head =
+            static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
       }
     }
     __syncwarp();
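The self-loop over-count called out in the next hunk deserves a worked example (illustration only, not code from the patch; `mc_counter` and `local_counter` are hypothetical names for the multicast and unicast mappings of the same slot, and the helper assumes the mscclpp device atomics internode.cu already uses): with 4 NVL ranks bound to the multicast group, each rank's `multimem.red.add` of `n` lands on all 4 local copies, so every consumer reads `4 * n` where it expects `n` from its own producer.

    // Illustration of the self-loop over-count. With an NVLS multicast
    // pointer, multimem.red.add applies the add to EVERY bound rank's local
    // copy: 4 NVL ranks each publishing n through the multicast pointer
    // leave 4 * n in every rank's counter.
    __device__ void publish_self_loop_tail(uint64_t* mc_counter,
                                           uint64_t* local_counter, uint64_t n) {
      // WRONG for a per-rank counter: the add fans out to all 4 peers.
      //   asm volatile("multimem.red.release.sys.global.add.u64 [%0], %1;"
      //                :: "l"(mc_counter), "l"(n) : "memory");
      // RIGHT: a local release add touches only this rank's slot (the same
      // mscclpp device atomic the patch switches to).
      mscclpp::atomicFetchAdd(local_counter, n, mscclpp::memoryOrderRelease);
    }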
@@ -1072,11 +1062,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
       }
       // Owner advances tail counter for this peer.
       // Phase 4 fix: cross-node tail goes through direct fabric-VA
-      // store on peer's rdma_channel_tail slot (single writer, no
-      // atomic needed). The NVLS multimem.red.relaxed counter path
-      // is unreliable for cross-node visibility — empirically only
-      // partial increments are seen by the consumer's ld.acquire.
-      // Self path (dst == self) still uses NVLS for intra-rank.
+      // store on peer's rdma_channel_tail slot. Self-loop tail goes
+      // through a plain local atomicAdd on rdma_channel_tail.buffer(rdma_rank)
+      // — NVLS multicast is WRONG for self-loop because it fans out
+      // to all bound NVL peers' buffers (4 NVL ranks × n_issue ⇒ 4x
+      // over-count on each consumer's read).
       if (owner) {
         if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
           const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
@@ -1087,12 +1077,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
               reinterpret_cast<uintptr_t>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
           const uint64_t new_tail = (uint64_t)issue_tail + (uint64_t)n_issue;
           asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
-          if (rank == 0) {
-            printf("[ph4-T] sender wrote tail rank=%d ch=%d dst=%d new_tail=%lu peer=%p\n",
-                   rank, channel_id, dst_rdma_rank, (unsigned long)new_tail, (void*)peer_tail);
-          }
         } else {
-          nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank, (uint64_t)n_issue);
+          // Self-loop: plain release atomic on local slot (no multicast).
+          mscclpp::atomicFetchAdd(
+              reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
+              (uint64_t)n_issue, mscclpp::memoryOrderRelease);
         }
       }
     } else if (owner) {
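The cross-node publish in the hunk above is a single-writer pattern: exactly one owner produces the tail for a given (channel, peer) slot, so it can store the absolute value with a release store over the fabric VA instead of issuing a remote atomic. A minimal sketch of the pairing, using the same PTX as the patch (`peer_base` and `slot_off` are hypothetical stand-ins for `peer_rdma_bases[dst_rank_global]` and the computed slot offset):

    // Producer: single writer per slot, so an absolute release store suffices.
    __device__ void publish_tail_cross_node(char* peer_base, uintptr_t slot_off,
                                            uint64_t new_tail) {
      uint64_t* peer_tail = reinterpret_cast<uint64_t*>(peer_base + slot_off);
      asm volatile("st.release.sys.global.u64 [%0], %1;"
                   :: "l"(peer_tail), "l"(new_tail) : "memory");
    }

    // Consumer: this acquire load pairs with the release store above.
    __device__ uint64_t read_tail(const uint64_t* local_tail) {
      uint64_t v;
      asm volatile("ld.acquire.sys.global.u64 %0, [%1];"
                   : "=l"(v) : "l"(local_tail) : "memory");
      return v;
    }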
@@ -1230,20 +1219,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
           // fabric-VA store to local rdma_channel_tail.buffer(src_rdma_rank),
           // so prefer the legacy ld_acquire read which sees those stores.
           // NVLS counter only used for self path (single rank).
-          if (peer_rdma_bases != nullptr && src_rdma_rank != rdma_rank) {
-            cached_rdma_channel_tail =
-                static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
-            if (rank == 4 && lane_id == src_rdma_rank && cached_rdma_channel_tail > cached_rdma_channel_head) {
-              printf("[ph4-Tr] fwd read tail rank=%d ch=%d src=%d val=%d head=%d\n",
-                     rank, channel_id, src_rdma_rank, cached_rdma_channel_tail, cached_rdma_channel_head);
-            }
-          } else if (nvls_tail_dev != nullptr) {
-            uint64_t v = nvls_ctr_load(nvls_tail_dev, channel_id, src_rdma_rank, rdma_rank);
-            cached_rdma_channel_tail = static_cast<int>(v);
-          } else {
-            cached_rdma_channel_tail =
-                static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
-          }
+          // Phase 4 fix: cross-node tail comes via direct fabric-VA store
+          // (sender writes peer's rdma_channel_tail slot). Self-loop tail
+          // is a plain local atomic. Both end up in rdma_channel_tail —
+          // a single read path via ld_acquire_sys_global covers them.
+          cached_rdma_channel_tail =
+              static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
         }
         if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) break;
       }
@@ -1261,6 +1242,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
       auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
       auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
 
+      if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
+        printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
+               rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
+               num_tokens_to_recv_from_rdma);
+      }
+
       // Iterate over every token from the RDMA buffer
       for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
         auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
@@ -1324,7 +1311,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
 
       // Retired
       __syncwarp();
-      if (lane_id == 0) forward_channel_retired[dst_nvl_rank] = true;
+      if (lane_id == 0) {
+        forward_channel_retired[dst_nvl_rank] = true;
+        if (channel_id == 0) {
+          printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
+        }
+      }
     } else if (warp_role == WarpRole::kForwarderCoordinator) {
       // Extra warps for forwarder coordinator should exit directly
       if (target_rank > 0) return;
@@ -1365,14 +1357,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
             uint64_t* peer_head = reinterpret_cast<uint64_t*>(
                 reinterpret_cast<uintptr_t>(peer_rdma_bases[dst_rank_global]) + my_head_off);
             asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
-          } else if (nvls_head_mc != nullptr) {
-            // Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
-            // consumer=rdma_rank). Same coordinate as reader-side
-            // `nvls_ctr_load(nvls_head_dev, channel, rdma_rank, lane_id)` —
-            // (P, C) pair is canonical regardless of who reads/writes.
-            nvls_ctr_add(nvls_head_mc, channel_id, lane_id, rdma_rank,
-                         (uint64_t)(min_head - last_head));
          } else if (lane_id == rdma_rank) {
+            // Self-loop: plain release atomic on local slot. Cannot use NVLS
+            // multimem here — it fans out to all NVL peers' local buffers
+            // and over-counts (4 NVL ranks × add ⇒ 4x increment).
            mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
                                    (uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease);
          } else {
@@ -1487,7 +1475,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) dispatch(
       if (lane_id == 0) st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
     }
   }
-  if (thread_id == 0 && channel_id == 0 && rank == 0) {
+  if (thread_id == 0 && channel_id == 0) {
     printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
   }
 }
@@ -1542,7 +1530,12 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
                               int* combined_nvl_head, void* rdma_buffer_ptr, void** buffer_ptrs,
                               int** task_fifo_ptrs, int head, int rank, int num_ranks, bool is_cached_dispatch,
                               mscclpp::PortChannelDeviceHandle* port_channel_handles,
-                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
+                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                              // Phase 4: NVLS multimem barrier path — bypasses port_channel signal/wait
+                              // (broken on Azure CX-7 RoCE) by using multimem.red.add + ld.acquire on a
+                              // pair of barrier slots (offsets +24 and +32). Both nullptr ⇒ fall back to IB.
+                              void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier,
+                              uint64_t nvls_epoch) {
   auto sm_id = static_cast<int>(blockIdx.x);
   auto thread_id = static_cast<int>(threadIdx.x);
   auto num_threads = static_cast<int>(blockDim.x);
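One reading of why the kernel needs two distinct slots (+24 and +32) rather than one: each launch runs two barriers back to back, and if both spun on the same slot with the same `epoch * num_ranks` target, the first barrier's arrivals would immediately satisfy a fast rank's wait at the second. Distinct slots let both barriers share a single epoch counter. A sketch of the address math (only the +24/+32 offsets come from the patch; the helper itself is hypothetical):

    // Computes the two cached_notify barrier slots from a base mapping.
    // which == 0 selects the first barrier (+24), which == 1 the second (+32).
    __device__ uint64_t* barrier_slot(void* base, size_t off_barrier, int which) {
      return reinterpret_cast<uint64_t*>(static_cast<char*>(base) +
                                         off_barrier + 24 + which * 8);
    }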
@@ -1555,16 +1548,33 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
 
   // Using two SMs, which clean the RDMA/NVL buffer respectively
   if (sm_id == 0) {
-    // Barrier for RDMA
-
-    // TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
-    const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
-    const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
-    if (run_barrier) {
-      port_channel_handles[barrier_channel_idx].signal();
-      port_channel_handles[barrier_channel_idx].wait();
-    }
-    __syncthreads();
+    if (thread_id == 0) {
+      printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
+    }
+    // Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
+    // the port_channel signal/wait pair (broken on Azure CX-7).
+    if (nvls_mc_ptr != nullptr) {
+      if (thread_id == 0) {
+        uint64_t* mc_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 24);
+        uint64_t* dev_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 24);
+        asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b3) : "memory");
+        const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
+        uint64_t v;
+        do {
+          asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b3) : "memory");
+        } while (v < expected);
+      }
+      __syncthreads();
+    } else {
+      // TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
+      const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
+      const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
+      if (run_barrier) {
+        port_channel_handles[barrier_channel_idx].signal();
+        port_channel_handles[barrier_channel_idx].wait();
+      }
+      __syncthreads();
+    }
 
     // Clean
     auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
@@ -1577,11 +1587,27 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
     __threadfence_system();
     __syncthreads();
 
-    // Barrier again
-    if (run_barrier) {
-      port_channel_handles[barrier_channel_idx].signal();
-      port_channel_handles[barrier_channel_idx].flush();
-      port_channel_handles[barrier_channel_idx].wait();
+    // Barrier again — Phase 4: second NVLS multimem.red barrier on slot +32.
+    if (nvls_mc_ptr != nullptr) {
+      if (thread_id == 0) {
+        uint64_t* mc_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 32);
+        uint64_t* dev_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 32);
+        asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b4) : "memory");
+        const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
+        uint64_t v;
+        do {
+          asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b4) : "memory");
+        } while (v < expected);
+      }
+      __syncthreads();
+    } else {
+      const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
+      const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
+      if (run_barrier) {
+        port_channel_handles[barrier_channel_idx].signal();
+        port_channel_handles[barrier_channel_idx].flush();
+        port_channel_handles[barrier_channel_idx].wait();
+      }
     }
   } else if (sm_id == 1) {
     // Barrier for NVL
@@ -1659,7 +1685,8 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank,
                    cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                    bool is_cached_dispatch, bool low_latency_mode, mscclpp::PortChannelDeviceHandle* port_channel_handles,
-                   mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
+                   mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                   void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier, uint64_t nvls_epoch) {
   const int num_threads = std::max(128, 32 * num_channels);
   const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
 
@@ -1681,7 +1708,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
               nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels,
               rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr,
               buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks, is_cached_dispatch, port_channel_handles,
-              memory_channel_handles);
+              memory_channel_handles, nvls_mc_ptr, nvls_dev_ptr, nvls_off_barrier, nvls_epoch);
 }
 
 template <...>
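It is worth spelling out why `expected = nvls_epoch * num_ranks` is the right spin target: the multicast add from each rank lands on every rank's unicast copy, so one launch contributes exactly `num_ranks` to every local slot, and after `E` launches the slot holds `E * num_ranks` everywhere. The same fan-out that broke the per-peer counters is exactly what a barrier wants. Condensed into a standalone helper (a sketch, assuming `mc_slot`/`dev_slot` are the multicast and unicast views of one slot on NVLS-capable hardware; the PTX is the same as in the kernel above):

    // All-ranks epoch barrier over an NVLS-mapped slot. Each caller adds 1
    // through the multicast view; the add is applied to every rank's unicast
    // copy, so every rank's dev_slot reaches epoch * num_ranks once all ranks
    // have arrived.
    __device__ void epoch_barrier(uint64_t* mc_slot, const uint64_t* dev_slot,
                                  uint64_t epoch, int num_ranks) {
      asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;"
                   :: "l"(mc_slot) : "memory");
      const uint64_t expected = epoch * static_cast<uint64_t>(num_ranks);
      uint64_t v;
      do {
        asm volatile("ld.acquire.sys.global.u64 %0, [%1];"
                     : "=l"(v) : "l"(dev_slot) : "memory");
      } while (v < expected);
    }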
@@ -1772,6 +1799,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) combine(
   // NOTES: we decouple a channel into 2 SMs
   const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
 
+  if (rank == 0 && thread_id == 0 && channel_id == 0) {
+    printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
+           rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
+  }
   auto role_meta = [=]() -> std::pair<WarpRole, int> {
     auto warp_id = thread_id / 32;
     if (not is_rdma_receiver_sm) {
@@ -2017,15 +2048,11 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) combine(
         while (sub_warp_id == 0 and lane_id == 0) {
          // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
          // Here, `token_start_idx` is the actual tail
-          // Phase 3: NVLS counter fast path. Slot keyed (producer=rdma_rank,
-          // consumer=dst_rdma_rank). I'm the producer here.
-          int num_used_slots;
-          if (nvls_head_dev != nullptr) {
-            num_used_slots = token_start_idx -
-                static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, dst_rdma_rank));
-          } else {
-            num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
-          }
+          // Phase 4: head read — cross-node head feedback comes from the peer
+          // via fabric-VA store, self-loop head from a local atomic. Both end
+          // up in rdma_channel_head; one read path covers them.
+          int num_used_slots = token_start_idx -
+              static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)));
           if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break;
 
           // Timeout check
@@ -2123,9 +2150,25 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) combine(
             handle.flush();
           }
         }
-        if (lane_id == 0) {
-          nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank,
-                       (uint64_t)num_chunked_tokens);
+        // Phase 4: tail counter publish.
+        //   cross-node → direct fabric-VA st.release on peer's tail slot
+        //   self-loop  → plain local atomicAdd (NVLS multicast over-counts
+        //   across NVL peers — see dispatch fix).
+        if (dst_rdma_rank != rdma_rank) {
+          if (peer_rdma_bases != nullptr && lane_id == 0) {
+            const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
+            const uintptr_t my_tail_off =
+                reinterpret_cast<uintptr_t>(rdma_channel_tail.buffer(rdma_rank)) -
+                reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
+            uint64_t* peer_tail = reinterpret_cast<uint64_t*>(
+                reinterpret_cast<uintptr_t>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
+            const uint64_t new_tail = (uint64_t)(token_start_idx + num_chunked_tokens);
+            asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
+          }
+        } else if (lane_id == 0) {
+          mscclpp::atomicFetchAdd(
+              reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
+              (uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease);
         }
       } else if (lane_id == 0) {
         if (dst_rdma_rank == rdma_rank) {
@@ -2188,14 +2231,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) combine(
       // Wait lanes to be ready
       auto start_time = clock64();
       while (cached_channel_tail_idx <= expected_head) {
-        // Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
-        // consumer=rdma_rank). I'm the consumer here, peer is producer.
-        if (nvls_tail_dev != nullptr) {
-          cached_channel_tail_idx = static_cast<int>(
-              nvls_ctr_load(nvls_tail_dev, channel_id, lane_id, rdma_rank));
-        } else {
-          cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
-        }
+        // Phase 4: receiver waits on rdma_channel_tail. The cross-node sender
+        // wrote it via fabric-VA, the self sender via a local atomic.
+        // One read path via ld_acquire_sys_global covers both.
+        cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
 
         // Timeout check
         if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
@@ -2255,11 +2294,17 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) combine(
         if (not rdma_receiver_retired[i])
           min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
       if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-        if (nvls_head_mc != nullptr) {
-          // Phase 3: NVLS counter fast path. Slot keyed (producer=dst_rdma_rank,
-          // consumer=rdma_rank). I'm consuming, peer is producer.
-          nvls_ctr_add(nvls_head_mc, channel_id, dst_rdma_rank, rdma_rank,
-                       (uint64_t)(min_head - last_rdma_head));
+        // Phase 4: head feedback path.
+        //   cross-node → fabric-VA st.release on peer's head slot
+        //   self-loop  → local atomicAdd
+        if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
+          const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
+          const uintptr_t my_head_off =
+              reinterpret_cast<uintptr_t>(rdma_channel_head.buffer(rdma_rank)) -
+              reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
+          uint64_t* peer_head = reinterpret_cast<uint64_t*>(
+              reinterpret_cast<uintptr_t>(peer_rdma_bases[dst_rank_global]) + my_head_off);
+          asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
        } else if (dst_rdma_rank == rdma_rank) {
          mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
                                  (uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease);
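For reference, the head/tail discipline all of these publish/wait paths implement, reduced to a minimal host-side model (hypothetical names; `capacity` corresponds to `num_max_rdma_chunked_recv_tokens`): the producer owns `tail`, the consumer owns `head`, both grow monotonically, and the producer may publish `n` more tokens only when `capacity - (tail - head) >= n`. That is the `num_used_slots` check in combine and the `rdma_tail_idx - cached_rdma_channel_head` wait in dispatch.

    #include <cstdint>

    // Minimal model of the credit protocol behind the spin loops above.
    struct Ring {
      uint64_t head = 0;      // consumer feedback (fabric-VA store / local atomic)
      uint64_t tail = 0;      // producer publish  (fabric-VA store / local atomic)
      uint64_t capacity = 0;  // num_max_rdma_chunked_recv_tokens
    };

    // True when the producer may publish n more tokens without overwriting
    // slots the consumer has not drained yet.
    inline bool can_publish(const Ring& r, uint64_t n) {
      return r.capacity - (r.tail - r.head) >= n;
    }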