mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
ext/ep: WIP Phase 4 fix NVLS self-overcount + cached_notify NVLS barrier
Two related root causes prevented Phase 4 (internode dispatch+combine) from completing on Azure GB200 NVL72 with the IB control-plane disabled. 1) NVLS self-loop over-count: The sender/forwarder counter-publish path used multimem.red.add, which multicasts to every NVL peer that has the buffer bound. When a single logical writer issues an add, every peer adds it to its own slot, so self-loop counters (where one rank is both writer and reader on the same (P,C) pair) over-count by N = number of NVL peers. Fix: replace all NVLS-based self-counter sites in dispatch+combine with plain local mscclpp::atomicFetchAdd. Cross-node visibility was already handled separately via direct fabric-VA st.release.sys.global on peer_rdma_bases. 2) cached_notify barrier hang on Azure CX-7 RoCE: The two port_channel.signal/wait pairs in cached_notify hang on this platform because RoCE control-plane traffic is broken. Fix: add an NVLS multimem.red barrier path (barrier slots +24 / +32) that mirrors the working notify_dispatch pattern. Threaded the nvls_mc_ptr / nvls_dev_ptr / nvls_off_barrier / nvls_epoch params through api.cuh + buffer.cc; introduced a separate nvls_ht_cached_epoch member because slots +24 / +32 are only touched when the cached path is taken — sharing the global nvls_ht_epoch would mismatch slot increments and wait expectations. End-to-end test_internode.py: dispatch + combine PASS with max|got - expected| = 0.0 across all ranks.
This commit is contained in:
@@ -1228,12 +1228,16 @@ Buffer::internode_dispatch(
|
||||
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
|
||||
|
||||
// Just a barrier and clean flags
|
||||
if (nvls_ht_enabled) ++nvls_ht_cached_epoch;
|
||||
internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, nullptr,
|
||||
nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
|
||||
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank,
|
||||
comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
|
||||
num_nvl_bytes, true, low_latency_mode, port_channel_handles_device_ptr.get(),
|
||||
memory_channel_handles_device_ptr.get());
|
||||
memory_channel_handles_device_ptr.get(),
|
||||
nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
|
||||
nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
|
||||
nvls_ht_off_barrier, nvls_ht_cached_epoch);
|
||||
move_fifo_slots(2);
|
||||
} else {
|
||||
rdma_channel_prefix_matrix =
|
||||
@@ -1459,7 +1463,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
|
||||
combined_nvl_head.data_ptr<int>(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
|
||||
config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream,
|
||||
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode,
|
||||
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
|
||||
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
|
||||
nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
|
||||
nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
|
||||
nvls_ht_off_barrier, (nvls_ht_enabled ? ++nvls_ht_cached_epoch : 0));
|
||||
move_fifo_slots(2);
|
||||
|
||||
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
|
||||
@@ -147,6 +147,11 @@ struct Buffer {
|
||||
// before each kernel launch that uses an NVLS barrier; the kernel spins
|
||||
// until the barrier slot reaches `epoch * num_ranks`.
|
||||
uint64_t nvls_ht_epoch = 0;
|
||||
// Independent epoch for cached_notify barrier slots (offsets +24 / +32),
|
||||
// since those slots are only touched when the cached path is taken — using
|
||||
// the shared `nvls_ht_epoch` would over-count the expected value relative
|
||||
// to the number of times those particular slots have actually been bumped.
|
||||
uint64_t nvls_ht_cached_epoch = 0;
|
||||
// Worst-case shape parameters used to size the buffer:
|
||||
// stride_per_channel = num_rdma_ranks * num_rdma_ranks (counter slots)
|
||||
// We allocate for `kNvlsMaxChannels` so any `num_sms` config fits.
|
||||
|
||||
@@ -104,7 +104,9 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
||||
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
|
||||
void* nvls_mc_ptr = nullptr, void* nvls_dev_ptr = nullptr,
|
||||
size_t nvls_off_barrier = 0, uint64_t nvls_epoch = 0);
|
||||
|
||||
void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank,
|
||||
const void* x, const float* topk_weights, const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
|
||||
@@ -880,21 +880,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
if (is_token_in_rank_uint64 != 0) {
|
||||
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
|
||||
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
|
||||
// Phase 4: prefer the legacy local-load when peer_rdma_bases
|
||||
// is in use — cross-node head feedback now writes the
|
||||
// absolute value directly into our local rdma_channel_head
|
||||
// slot via fabric VA, so a `ld_volatile_global` here sees it.
|
||||
if (peer_rdma_bases != nullptr && lane_id != rdma_rank) {
|
||||
cached_rdma_channel_head =
|
||||
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
|
||||
} else if (nvls_head_dev != nullptr) {
|
||||
// Phase 3: NVLS fast path (self loop only when fabric VA on).
|
||||
cached_rdma_channel_head =
|
||||
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, lane_id));
|
||||
} else {
|
||||
cached_rdma_channel_head =
|
||||
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
|
||||
}
|
||||
// Phase 4: head feedback path — cross-node uses fabric-VA store,
|
||||
// self-loop uses local atomic. Both end up in rdma_channel_head;
|
||||
// single read via ld_volatile_global covers them. NVLS removed.
|
||||
cached_rdma_channel_head =
|
||||
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
@@ -1072,11 +1062,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
}
|
||||
// Owner advances tail counter for this peer.
|
||||
// Phase 4 fix: cross-node tail goes through direct fabric-VA
|
||||
// store on peer's rdma_channel_tail slot (single writer, no
|
||||
// atomic needed). The NVLS multimem.red.relaxed counter path
|
||||
// is unreliable for cross-node visibility — empirically only
|
||||
// partial increments are seen by the consumer's ld.acquire.
|
||||
// Self path (dst == self) still uses NVLS for intra-rank.
|
||||
// store on peer's rdma_channel_tail slot. Self-loop tail goes
|
||||
// through plain local atomicAdd on rdma_channel_tail.buffer(rdma_rank)
|
||||
// — NVLS multicast is WRONG for self-loop because it fans out
|
||||
// to all bound NVL peers' buffers (4 NVL ranks × n_issue ⇒ 4x
|
||||
// over-count on each consumer's read).
|
||||
if (owner) {
|
||||
if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
|
||||
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
|
||||
@@ -1087,12 +1077,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
|
||||
const uint64_t new_tail = (uint64_t)issue_tail + (uint64_t)n_issue;
|
||||
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
|
||||
if (rank == 0) {
|
||||
printf("[ph4-T] sender wrote tail rank=%d ch=%d dst=%d new_tail=%lu peer=%p\n",
|
||||
rank, channel_id, dst_rdma_rank, (unsigned long)new_tail, (void*)peer_tail);
|
||||
}
|
||||
} else {
|
||||
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank, (uint64_t)n_issue);
|
||||
// Self-loop: plain release atomic on local slot (no multicast).
|
||||
mscclpp::atomicFetchAdd(
|
||||
reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
|
||||
(uint64_t)n_issue, mscclpp::memoryOrderRelease);
|
||||
}
|
||||
}
|
||||
} else if (owner) {
|
||||
@@ -1230,20 +1219,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
// fabric-VA store to local rdma_channel_tail.buffer(src_rdma_rank),
|
||||
// so prefer the legacy ld_acquire read which sees those stores.
|
||||
// NVLS counter only used for self path (single rank).
|
||||
if (peer_rdma_bases != nullptr && src_rdma_rank != rdma_rank) {
|
||||
cached_rdma_channel_tail =
|
||||
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
|
||||
if (rank == 4 && lane_id == src_rdma_rank && cached_rdma_channel_tail > cached_rdma_channel_head) {
|
||||
printf("[ph4-Tr] fwd read tail rank=%d ch=%d src=%d val=%d head=%d\n",
|
||||
rank, channel_id, src_rdma_rank, cached_rdma_channel_tail, cached_rdma_channel_head);
|
||||
}
|
||||
} else if (nvls_tail_dev != nullptr) {
|
||||
uint64_t v = nvls_ctr_load(nvls_tail_dev, channel_id, src_rdma_rank, rdma_rank);
|
||||
cached_rdma_channel_tail = static_cast<int>(v);
|
||||
} else {
|
||||
cached_rdma_channel_tail =
|
||||
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
|
||||
}
|
||||
// Phase 4 fix: cross-node tail comes via direct fabric-VA store
|
||||
// (sender writes peer's rdma_channel_tail slot). Self-loop tail
|
||||
// is a plain local atomic. Both end up in rdma_channel_tail —
|
||||
// single read path via ld_acquire_sys_global covers them.
|
||||
cached_rdma_channel_tail =
|
||||
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
|
||||
}
|
||||
if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) break;
|
||||
}
|
||||
@@ -1261,6 +1242,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
|
||||
auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
|
||||
|
||||
if (rank == 4 && lane_id == 0 && dst_nvl_rank == 0 && channel_id == 0) {
|
||||
printf("[ph4-Loop] fwd loop rank=%d ch=%d dst_nvl=%d src_rdma=%d head=%d tail=%d nrecv=%d\n",
|
||||
rank, channel_id, dst_nvl_rank, src_rdma_rank, src_rdma_head, src_rdma_tail,
|
||||
num_tokens_to_recv_from_rdma);
|
||||
}
|
||||
|
||||
// Iterate over every token from the RDMA buffer
|
||||
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
|
||||
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
|
||||
@@ -1324,7 +1311,12 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
|
||||
// Retired
|
||||
__syncwarp();
|
||||
if (lane_id == 0) forward_channel_retired[dst_nvl_rank] = true;
|
||||
if (lane_id == 0) {
|
||||
forward_channel_retired[dst_nvl_rank] = true;
|
||||
if (channel_id == 0) {
|
||||
printf("[ph4-Fr] fwd retire rank=%d ch=%d dst_nvl=%d\n", rank, channel_id, dst_nvl_rank);
|
||||
}
|
||||
}
|
||||
} else if (warp_role == WarpRole::kForwarderCoordinator) {
|
||||
// Extra warps for forwarder coordinator should exit directly
|
||||
if (target_rank > 0) return;
|
||||
@@ -1365,14 +1357,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
uint64_t* peer_head = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_head_off);
|
||||
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
|
||||
} else if (nvls_head_mc != nullptr) {
|
||||
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
|
||||
// consumer=rdma_rank). Same coordinate as reader-side
|
||||
// `nvls_ctr_load(nvls_head_dev, channel, rdma_rank, lane_id)` —
|
||||
// (P, C) pair is canonical regardless of who reads/writes.
|
||||
nvls_ctr_add(nvls_head_mc, channel_id, lane_id, rdma_rank,
|
||||
(uint64_t)(min_head - last_head));
|
||||
} else if (lane_id == rdma_rank) {
|
||||
// Self-loop: plain release atomic on local slot. Cannot use NVLS
|
||||
// multimem here — it fans out to all NVL peers' local buffers
|
||||
// and over-counts (4 NVL ranks × add ⇒ 4x increment).
|
||||
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
|
||||
(uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease);
|
||||
} else {
|
||||
@@ -1487,7 +1475,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
|
||||
if (lane_id == 0) st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
|
||||
}
|
||||
}
|
||||
if (thread_id == 0 && channel_id == 0 && rank == 0) {
|
||||
if (thread_id == 0 && channel_id == 0) {
|
||||
printf("[ph4-Z] dispatch exit rank=%d sm=%d\n", rank, sm_id);
|
||||
}
|
||||
}
|
||||
@@ -1542,7 +1530,12 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
int* combined_nvl_head, void* rdma_buffer_ptr, void** buffer_ptrs, int** task_fifo_ptrs,
|
||||
int head, int rank, int num_ranks, bool is_cached_dispatch,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
|
||||
// Phase 4: NVLS multimem barrier path — bypasses port_channel signal/wait
|
||||
// (broken on Azure CX-7 RoCE) by using multimem.red.add + ld.acquire on a
|
||||
// pair of barrier slots (offsets +24 and +32). Both nullptr ⇒ fall back to IB.
|
||||
void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier,
|
||||
uint64_t nvls_epoch) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
auto num_threads = static_cast<int>(blockDim.x);
|
||||
@@ -1555,16 +1548,33 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
|
||||
// Using two SMs, which clean the RDMA/NVL buffer respectively
|
||||
if (sm_id == 0) {
|
||||
// Barrier for RDMA
|
||||
|
||||
// TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
|
||||
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
|
||||
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
|
||||
if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
port_channel_handles[barrier_channel_idx].wait();
|
||||
if (thread_id == 0) {
|
||||
printf("[ph4-CN0] cached_notify sm0 entry rank=%d\n", rank);
|
||||
}
|
||||
// Barrier for RDMA — Phase 4: NVLS multimem.red.add fast path replaces
|
||||
// the port_channel signal/wait pair (broken on Azure CX-7).
|
||||
if (nvls_mc_ptr != nullptr) {
|
||||
if (thread_id == 0) {
|
||||
uint64_t* mc_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 24);
|
||||
uint64_t* dev_b3 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 24);
|
||||
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b3) : "memory");
|
||||
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
|
||||
uint64_t v;
|
||||
do {
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b3) : "memory");
|
||||
} while (v < expected);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
// TODO(chhwang): it should be a global barrier when kLowLatencyMode is false
|
||||
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
|
||||
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
|
||||
if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
port_channel_handles[barrier_channel_idx].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Clean
|
||||
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
|
||||
@@ -1577,11 +1587,27 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
__threadfence_system();
|
||||
__syncthreads();
|
||||
|
||||
// Barrier again
|
||||
if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
port_channel_handles[barrier_channel_idx].flush();
|
||||
port_channel_handles[barrier_channel_idx].wait();
|
||||
// Barrier again — Phase 4: second NVLS multimem.red barrier on slot +32.
|
||||
if (nvls_mc_ptr != nullptr) {
|
||||
if (thread_id == 0) {
|
||||
uint64_t* mc_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 32);
|
||||
uint64_t* dev_b4 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 32);
|
||||
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b4) : "memory");
|
||||
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
|
||||
uint64_t v;
|
||||
do {
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b4) : "memory");
|
||||
} while (v < expected);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
const bool run_barrier = (threadIdx.x < num_rdma_ranks) && (threadIdx.x != rdma_rank);
|
||||
const auto barrier_channel_idx = kLowLatencyMode ? threadIdx.x : (threadIdx.x * NUM_MAX_NVL_PEERS + nvl_rank);
|
||||
if (run_barrier) {
|
||||
port_channel_handles[barrier_channel_idx].signal();
|
||||
port_channel_handles[barrier_channel_idx].flush();
|
||||
port_channel_handles[barrier_channel_idx].wait();
|
||||
}
|
||||
}
|
||||
} else if (sm_id == 1) {
|
||||
// Barrier for NVL
|
||||
@@ -1659,7 +1685,8 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
||||
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode,
|
||||
mscclpp::PortChannelDeviceHandle* port_channel_handles,
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
|
||||
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
|
||||
void* nvls_mc_ptr, void* nvls_dev_ptr, size_t nvls_off_barrier, uint64_t nvls_epoch) {
|
||||
const int num_threads = std::max(128, 32 * num_channels);
|
||||
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
|
||||
|
||||
@@ -1681,7 +1708,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
nvl_clean_meta.second, combined_rdma_head, num_combined_tokens, num_channels,
|
||||
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs,
|
||||
task_fifo_ptrs, head, rank, num_ranks, is_cached_dispatch, port_channel_handles,
|
||||
memory_channel_handles);
|
||||
memory_channel_handles, nvls_mc_ptr, nvls_dev_ptr, nvls_off_barrier, nvls_epoch);
|
||||
}
|
||||
|
||||
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
|
||||
@@ -1772,6 +1799,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
|
||||
// NOTES: we decouple a channel into 2 SMs
|
||||
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
if (rank == 0 && thread_id == 0 && channel_id == 0) {
|
||||
printf("[ph4-CK] combine entry rank=%d sm=%d num_channels=%d is_rdma_recv=%d\n",
|
||||
rank, sm_id, num_channels, (int)is_rdma_receiver_sm);
|
||||
}
|
||||
auto role_meta = [=]() -> std::pair<WarpRole, int> {
|
||||
auto warp_id = thread_id / 32;
|
||||
if (not is_rdma_receiver_sm) {
|
||||
@@ -2017,15 +2048,11 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
while (sub_warp_id == 0 and lane_id == 0) {
|
||||
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
|
||||
// Here, `token_start_idx` is the actual tail
|
||||
// Phase 3: NVLS counter fast path. Slot keyed (producer=rdma_rank,
|
||||
// consumer=dst_rdma_rank). I'm the producer here.
|
||||
int num_used_slots;
|
||||
if (nvls_head_dev != nullptr) {
|
||||
num_used_slots = token_start_idx -
|
||||
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, dst_rdma_rank));
|
||||
} else {
|
||||
num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
|
||||
}
|
||||
// Phase 4: head read — cross-node head feedback comes from peer
|
||||
// via fabric-VA store, self-loop head from local atomic. Both end
|
||||
// up in rdma_channel_head; one read path covers them.
|
||||
int num_used_slots = token_start_idx -
|
||||
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)));
|
||||
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break;
|
||||
|
||||
// Timeout check
|
||||
@@ -2123,9 +2150,25 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
handle.flush();
|
||||
}
|
||||
}
|
||||
if (lane_id == 0) {
|
||||
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank,
|
||||
(uint64_t)num_chunked_tokens);
|
||||
// Phase 4: tail counter publish.
|
||||
// cross-node → direct fabric-VA st.release on peer's tail slot
|
||||
// self-loop → plain local atomicAdd (NVLS multicast over-counts
|
||||
// across NVL peers — see dispatch fix).
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
if (peer_rdma_bases != nullptr && lane_id == 0) {
|
||||
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
|
||||
const uintptr_t my_tail_off =
|
||||
reinterpret_cast<uintptr_t>(rdma_channel_tail.buffer(rdma_rank)) -
|
||||
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
|
||||
uint64_t* peer_tail = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_tail_off);
|
||||
const uint64_t new_tail = (uint64_t)(token_start_idx + num_chunked_tokens);
|
||||
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_tail), "l"(new_tail) : "memory");
|
||||
}
|
||||
} else if (lane_id == 0) {
|
||||
mscclpp::atomicFetchAdd(
|
||||
reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
|
||||
(uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease);
|
||||
}
|
||||
} else if (lane_id == 0) {
|
||||
if (dst_rdma_rank == rdma_rank) {
|
||||
@@ -2188,14 +2231,10 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
// Wait lanes to be ready
|
||||
auto start_time = clock64();
|
||||
while (cached_channel_tail_idx <= expected_head) {
|
||||
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
|
||||
// consumer=rdma_rank). I'm the consumer here, peer is producer.
|
||||
if (nvls_tail_dev != nullptr) {
|
||||
cached_channel_tail_idx = static_cast<int>(
|
||||
nvls_ctr_load(nvls_tail_dev, channel_id, lane_id, rdma_rank));
|
||||
} else {
|
||||
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
|
||||
}
|
||||
// Phase 4: receiver waits on rdma_channel_tail. Cross-node sender
|
||||
// wrote it via fabric-VA, self sender wrote it via local atomic.
|
||||
// One read path via ld_acquire_sys_global covers both.
|
||||
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
|
||||
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
@@ -2255,11 +2294,17 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
||||
if (min_head != std::numeric_limits<int>::max() and
|
||||
min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||
if (nvls_head_mc != nullptr) {
|
||||
// Phase 3: NVLS counter fast path. Slot keyed (producer=dst_rdma_rank,
|
||||
// consumer=rdma_rank). I'm consuming, peer is producer.
|
||||
nvls_ctr_add(nvls_head_mc, channel_id, dst_rdma_rank, rdma_rank,
|
||||
(uint64_t)(min_head - last_rdma_head));
|
||||
// Phase 4: head feedback path.
|
||||
// cross-node → fabric-VA st.release on peer's head slot
|
||||
// self-loop → local atomicAdd
|
||||
if (peer_rdma_bases != nullptr && dst_rdma_rank != rdma_rank) {
|
||||
const int dst_rank_global = dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
|
||||
const uintptr_t my_head_off =
|
||||
reinterpret_cast<uintptr_t>(rdma_channel_head.buffer(rdma_rank)) -
|
||||
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
|
||||
uint64_t* peer_head = reinterpret_cast<uint64_t*>(
|
||||
reinterpret_cast<uint8_t*>(peer_rdma_bases[dst_rank_global]) + my_head_off);
|
||||
asm volatile("st.release.sys.global.u64 [%0], %1;" :: "l"(peer_head), "l"((uint64_t)min_head) : "memory");
|
||||
} else if (dst_rdma_rank == rdma_rank) {
|
||||
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
|
||||
(uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease);
|
||||
|
||||
Reference in New Issue
Block a user