ext/ep: WIP Phase 4 fix NVLS self-overcount + cached_notify NVLS barrier

Two related root causes prevented Phase 4 (internode dispatch+combine) from
completing on Azure GB200 NVL72 with the IB control-plane disabled.

1) NVLS self-loop over-count
   The sender/forwarder counter-publish path used multimem.red.add, which
   multicasts to every NVL peer that has the buffer bound. When a single
   logical writer issues an add, every peer applies it to its own copy of
   the slot, so self-loop counters (where one rank is both writer and
   reader on the same (P,C) pair) over-count by a factor of N, the number
   of NVL peers.

   Fix: replace all NVLS-based self-counter sites in dispatch+combine
   with plain local mscclpp::atomicFetchAdd. Cross-node visibility was
   already handled separately via direct fabric-VA st.release.sys.global
   on peer_rdma_bases.
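
   A minimal sketch of the corrected publish pattern, assuming mscclpp's
   device atomics as used elsewhere in the patch (publish_counter and its
   parameter names are illustrative, not identifiers from this commit):

   #include <cstdint>
   #include <mscclpp/atomic_device.hpp>  // mscclpp::atomicFetchAdd

   __device__ void publish_counter(uint64_t* slot_local, uint64_t* peer_slot,
                                   uint64_t n, uint64_t new_value,
                                   bool self_loop) {
     if (self_loop) {
       // One writer, one local reader: a plain local release add suffices.
       // multimem.red.add here would fan the increment out to every bound
       // NVL peer's copy and over-count by the number of NVL peers.
       mscclpp::atomicFetchAdd(slot_local, n, mscclpp::memoryOrderRelease);
     } else {
       // Cross-node: single writer per slot, so an absolute release store
       // through the peer's fabric-VA mapping is enough; the consumer sees
       // it via ld.acquire on its local slot.
       asm volatile("st.release.sys.global.u64 [%0], %1;"
                    :: "l"(peer_slot), "l"(new_value) : "memory");
     }
   }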

2) cached_notify barrier hang on Azure CX-7 RoCE
   The two port_channel.signal/wait pairs in cached_notify hang on this
   platform because RoCE control-plane traffic is broken.

   Fix: add an NVLS multimem.red barrier path (barrier slots +24 / +32)
   that mirrors the working notify_dispatch pattern. Threaded the
   nvls_mc_ptr / nvls_dev_ptr / nvls_off_barrier / nvls_epoch params
   through api.cuh + buffer.cc; introduced a separate
   nvls_ht_cached_epoch member because slots +24 / +32 are only touched
   when the cached path is taken — sharing the global nvls_ht_epoch
   would mismatch slot increments and wait expectations.
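
   The barrier is epoch-counted: the host bumps the epoch before launch,
   each rank multicasts a +1 that lands in every rank's copy of the slot,
   and each kernel spins on its local copy until it reads
   epoch * num_ranks. A sketch of the device-side shape (nvls_barrier and
   slot_mc / slot_dev are illustrative stand-ins for the multicast and
   unicast mappings of one barrier slot):

   #include <cstdint>

   __device__ void nvls_barrier(uint64_t* slot_mc, uint64_t* slot_dev,
                                uint64_t epoch, int num_ranks) {
     if (threadIdx.x == 0) {
       // Arrive: the +1 is multicast into every rank's copy, so after all
       // num_ranks arrivals each local copy reads epoch * num_ranks.
       asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;"
                    :: "l"(slot_mc) : "memory");
       const uint64_t expected = epoch * static_cast<uint64_t>(num_ranks);
       uint64_t v;
       do {  // Wait: acquire-load the local (unicast) copy until all arrive.
         asm volatile("ld.acquire.sys.global.u64 %0, [%1];"
                      : "=l"(v) : "l"(slot_dev) : "memory");
       } while (v < expected);
     }
     __syncthreads();
   }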

End-to-end test_internode.py: dispatch + combine PASS with
max|got - expected| = 0.0 across all ranks.
Qinghua Zhou
2026-05-10 07:23:36 +00:00
parent f2228b07bb
commit e0a1bb2c42
4 changed files with 152 additions and 93 deletions


@@ -1228,12 +1228,16 @@ Buffer::internode_dispatch(
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
// Just a barrier and clean flags
if (nvls_ht_enabled) ++nvls_ht_cached_epoch;
internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, nullptr,
nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank,
comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode, port_channel_handles_device_ptr.get(),
memory_channel_handles_device_ptr.get());
memory_channel_handles_device_ptr.get(),
nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
nvls_ht_off_barrier, nvls_ht_cached_epoch);
move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix =
@@ -1459,7 +1463,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
combined_nvl_head.data_ptr<int>(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode,
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
nvls_ht_off_barrier, (nvls_ht_enabled ? ++nvls_ht_cached_epoch : 0));
move_fifo_slots(2);
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());


@@ -147,6 +147,11 @@ struct Buffer {
// before each kernel launch that uses an NVLS barrier; the kernel spins
// until the barrier slot reaches `epoch * num_ranks`.
uint64_t nvls_ht_epoch = 0;
// Independent epoch for cached_notify barrier slots (offsets +24 / +32),
// since those slots are only touched when the cached path is taken — using
// the shared `nvls_ht_epoch` would over-count the expected value relative
// to the number of times those particular slots have actually been bumped.
uint64_t nvls_ht_cached_epoch = 0;
// Worst-case shape parameters used to size the buffer:
// stride_per_channel = num_rdma_ranks * num_rdma_ranks (counter slots)
// We allocate for `kNvlsMaxChannels` so any `num_sms` config fits.


@@ -104,7 +104,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_mc_ptr = nullptr, void* nvls_dev_ptr = nullptr,
size_t nvls_off_barrier = 0, uint64_t nvls_epoch = 0);
void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights, const int* combined_rdma_head, const int* combined_nvl_head,


@@ -880,21 +880,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
// Phase 4: prefer the legacy local-load when peer_rdma_bases
// is in use — cross-node head feedback now writes the
// absolute value directly into our local rdma_channel_head
// slot via fabric VA, so a `ld_volatile_global` here sees it.
if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
cached_rdma_channel_head =
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
} else if (nvls_head_dev != nullptr) {
// Phase 3: NVLS fast path (self loop only when fabric VA on).
cached_rdma_channel_head =
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, lane_id));
} else {
cached_rdma_channel_head =
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
// Phase 4: head feedback path — cross-node uses fabric-VA store,
// self-loop uses local atomic. Both end up in rdma_channel_head;
// single read via ld_volatile_global covers them. NVLS removed.
cached_rdma_channel_head =
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
}
__syncwarp();
@@ -1072,11 +1062,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
}
// Owner advances tail counter for this peer.
// Phase 4 fix: cross-node tail goes through direct fabric-VA
// store on peer's rdma_channel_tail slot (single writer, no
// atomic needed). The NVLS multimem.red.relaxed counter path
// is unreliable for cross-node visibility — empirically only
// partial increments are seen by the consumer's ld.acquire.
// Self path (dst == self) still uses NVLS for intra-rank.
// store on peer's rdma_channel_tail slot. Self-loop tail goes
// through plain local atomicAdd on rdma_channel_tail.buffer(rdma_rank)
// — NVLS multicast is WRONG for self-loop because it fans out
// to all bound NVL peers' buffers (4 NVL ranks × n_issue ⇒ 4x
// over-count on each consumer's read).
if (owner) {
if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
@@ -1087,12 +1077,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
const uint64_t new_tail = (uint64_t)issue_tail + (uint64_t)n_issue;
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
if (rank == 0) {
printf("[ph4-T] sender wrote tail rank=%d ch=%d dst=%d new_tail=%lu peer=%p\n",
rank, channel_id, dst_rdma_rank, (unsigned long)new_tail, (void*)peer_tail);
}
} else {
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank, (uint64_t)n_issue);
// Self-loop: plain release atomic on local slot (no multicast).
mscclpp::atomicFetchAdd(
reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
(uint64_t)n_issue, mscclpp::memoryOrderRelease);
}
}
} else if (owner) {
@@ -1230,20 +1219,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
// fabric-VA store to local rdma_channel_tail.buffer(src_rdma_rank),
// so prefer the legacy ld_acquire read which sees those stores.
// NVLS counter only used for self path (single rank).
if (peer_rdma_bases != nullptr && src_rdma_rank != rdma_rank) {
cached_rdma_channel_tail =
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
if (rank == 4 && lane_id == src_rdma_rank && cached_rdma_channel_tail > cached_rdma_channel_head) {
printf("[ph4-Tr] fwd read tail rank=%d ch=%d src=%d val=%d head=%d\n",
rank, channel_id, src_rdma_rank, cached_rdma_channel_tail, cached_rdma_channel_head);
}
} else if (nvls_tail_dev != nullptr) {
uint64_t v = nvls_ctr_load(nvls_tail_dev, channel_id, src_rdma_rank, rdma_rank);
cached_rdma_channel_tail = static_cast<int>(v);
} else {
cached_rdma_channel_tail =
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
}
// Phase 4 fix: cross-node tail comes via direct fabric-VA store
// (sender writes peer's rdma_channel_tail slot). Self-loop tail
// is a plain local atomic. Both end up in rdma_channel_tail —
// single read path via ld_acquire_sys_global covers them.
cached_rdma_channel_tail =
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
}
if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) break;
}
@@ -1261,6 +1242,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
num_tokens_to_recv_from_rdma);
}
// Iterate over every token from the RDMA buffer
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
@@ -1324,7 +1311,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
// Retired
__syncwarp();
if (lane_id == 0) forward_channel_retired[dst_nvl_rank] = true;
if (lane_id == 0) {
forward_channel_retired[dst_nvl_rank] = true;
if (channel_id == 0) {
printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
}
}
} else if (warp_role == WarpRole::kForwarderCoordinator) {
// Extra warps for forwarder coordinator should exit directly
if (target_rank > 0) return;
@@ -1365,14 +1357,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
uint64_t* peer_head = reinterpret_cast<uint64_t*>(
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_head_off);
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
} else if (nvls_head_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
// consumer=rdma_rank). Same coordinate as reader-side
// `nvls_ctr_load(nvls_head_dev, channel, rdma_rank, lane_id)` —
// (P, C) pair is canonical regardless of who reads/writes.
nvls_ctr_add(nvls_head_mc, channel_id, lane_id, rdma_rank,
(uint64_t)(min_head - last_head));
} else if (lane_id == rdma_rank) {
// Self-loop: plain release atomic on local slot. Cannot use NVLS
// multimem here — it fans out to all NVL peers' local buffers
// and over-counts (4 NVL ranks × add ⇒ 4x increment).
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
(uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease);
} else {
@@ -1487,7 +1475,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
if (lane_id == 0) st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
}
if (thread_id == 0 && channel_id == 0 && rank == 0) {
if (thread_id == 0 && channel_id == 0) {
printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
}
}
@@ -1542,7 +1530,12 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
int* combined_nvl_head, void* rdma_buffer_ptr, void** buffer_ptrs, int** task_fifo_ptrs,
int head, int rank, int num_ranks, bool is_cached_dispatch,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
// Phase 4: NVLS multimem barrier path — bypasses port_channel signal/wait
// (broken on Azure CX-7 RoCE) by using multimem.red.add + ld.acquire on a
// pair of barrier slots (offsets +24 and +32). Both nullptr ⇒ fall back to IB.
void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier,
uint64_t nvls_epoch) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
@@ -1555,16 +1548,33 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
// Using two SMs, which clean the RDMA/NVL buffer respectively
if (sm_id == 0) {
// Barrier for RDMA
// TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].wait();
if (thread_id == 0) {
printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
}
// Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
// the port_channel signal/wait pair (broken on Azure CX-7).
if (nvls_mc_ptr != nullptr) {
if (thread_id == 0) {
uint64_t* mc_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 24);
uint64_t* dev_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 24);
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b3) : "memory");
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
uint64_t v;
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b3) : "memory");
} while (v < expected);
}
__syncthreads();
} else {
// TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].wait();
}
__syncthreads();
}
__syncthreads();
// Clean
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
@@ -1577,11 +1587,27 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__threadfence_system();
__syncthreads();
// Barrier again
if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].flush();
port_channel_handles[barrier_channel_idx].wait();
// Barrier again — Phase 4: second NVLS multimem.red barrier on slot +32.
if (nvls_mc_ptr != nullptr) {
if (thread_id == 0) {
uint64_t* mc_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 32);
uint64_t* dev_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 32);
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b4) : "memory");
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
uint64_t v;
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b4) : "memory");
} while (v < expected);
}
__syncthreads();
} else {
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].flush();
port_channel_handles[barrier_channel_idx].wait();
}
}
} else if (sm_id == 1) {
// Barrier for NVL
@@ -1659,7 +1685,8 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier, uint64_t nvls_epoch) {
const int num_threads = std::max(128, 32 * num_channels);
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
@@ -1681,7 +1708,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs,
task_fifo_ptrs, head, rank, num_ranks, is_cached_dispatch, port_channel_handles,
memory_channel_handles);
memory_channel_handles, nvls_mc_ptr, nvls_dev_ptr, nvls_off_barrier, nvls_epoch);
}
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
@@ -1772,6 +1799,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
if (rank == 0 && thread_id == 0 && channel_id == 0) {
printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
}
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / 32;
if (not is_rdma_receiver_sm) {
@@ -2017,15 +2048,11 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
while (sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
// Phase 3: NVLS counter fast path. Slot keyed (producer=rdma_rank,
// consumer=dst_rdma_rank). I'm the producer here.
int num_used_slots;
if (nvls_head_dev != nullptr) {
num_used_slots = token_start_idx -
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, dst_rdma_rank));
} else {
num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
}
// Phase 4: head read — cross-node head feedback comes from peer
// via fabric-VA store, self-loop head from local atomic. Both end
// up in rdma_channel_head; one read path covers them.
int num_used_slots = token_start_idx -
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)));
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break;
// Timeout check
@@ -2123,9 +2150,25 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
handle.flush();
}
}
if (lane_id == 0) {
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank,
(uint64_t)num_chunked_tokens);
// Phase 4: tail counter publish.
// cross-node → direct fabric-VA st.release on peer's tail slot
// self-loop → plain local atomicAdd (NVLS multicast over-counts
// across NVL peers — see dispatch fix).
if (dst_rdma_rank != rdma_rank) {
if (peer_rdma_bases != nullptr && lane_id == 0) {
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
const uintptr_t my_tail_off =
reinterpret_cast<uintptr_t>(rdma_channel_tail.buffer(rdma_rank)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
uint64_t* peer_tail = reinterpret_cast<uint64_t*>(
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
const uint64_t new_tail = (uint64_t)(token_start_idx + num_chunked_tokens);
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
}
} else if (lane_id == 0) {
mscclpp::atomicFetchAdd(
reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
(uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease);
}
} else if (lane_id == 0) {
if (dst_rdma_rank == rdma_rank) {
@@ -2188,14 +2231,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
// Wait lanes to be ready
auto start_time = clock64();
while (cached_channel_tail_idx <= expected_head) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
// consumer=rdma_rank). I'm the consumer here, peer is producer.
if (nvls_tail_dev != nullptr) {
cached_channel_tail_idx = static_cast<int>(
nvls_ctr_load(nvls_tail_dev, channel_id, lane_id, rdma_rank));
} else {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
}
// Phase 4: receiver waits on rdma_channel_tail. Cross-node sender
// wrote it via fabric-VA, self sender wrote it via local atomic.
// One read path via ld_acquire_sys_global covers both.
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
@@ -2255,11 +2294,17 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and
min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
if (nvls_head_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=dst_rdma_rank,
// consumer=rdma_rank). I'm consuming, peer is producer.
nvls_ctr_add(nvls_head_mc, channel_id, dst_rdma_rank, rdma_rank,
(uint64_t)(min_head - last_rdma_head));
// Phase 4: head feedback path.
// cross-node → fabric-VA st.release on peer's head slot
// self-loop → local atomicAdd
if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
const uintptr_t my_head_off =
reinterpret_cast<uintptr_t>(rdma_channel_head.buffer(rdma_rank)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
uint64_t* peer_head = reinterpret_cast<uint64_t*>(
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_head_off);
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
} else if (dst_rdma_rank == rdma_rank) {
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
(uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease);