ext/ep: NVLS HT B2 phases 1-3 (notify_dispatch barrier + counter fast path)

Phase 1: allocate + bind a per-rank NVLS multicast region for HT
internode counters/barriers/summary, gated by isNvlsSupported() &&
num_rdma_ranks>1 && !low_latency_mode. Layout in buffer.hpp:
{tail, head, barrier, data} sub-regions, 64 channels, NMP^2 slots
each. Falls back transparently when NVLS unavailable.

Phase 2: replace cross-node port_channel.signal/wait barriers in
notify_dispatch (b0,b1) with epoch-monotonic NVLS multimem.red.add.u64
barriers + leader ld.acquire spin. Broadcast per-sender summary via
multimem.st.relaxed.sys.global.u32. Validated cross-node on Azure
GB200 NVL72 (2 nodes x 4 GPUs): rank 0 reads v=8 expected=8 reliably,
confirming connectNvlsCollective binds a fabric-wide multicast object
on this hardware.

Phase 3: route 4 cross-node tail/head counter sites in dispatch+combine
through NVLS multimem.red.add (with handle.flush() between data put
and counter add to preserve visibility ordering); add 3 device-inline
helpers (nvls_ctr_slot_index/add/load) and thread 4 base ptrs through
the kernel templates, LAUNCH_KERNEL macros, host wrappers, and the
buffer.cc call sites. Each NVLS branch is gated by nvls_*_mc != nullptr;
the legacy IB code path is preserved verbatim under else for non-NVLS
hardware.

Deploy note: mscclpp_ep_cpp.so lives at site-packages/mscclpp_ep_cpp.so
(top-level), not inside mscclpp/. Multi-node deploys must scp it
explicitly in addition to rsyncing the package dir.

Status: phases 1-3 build and deploy clean. Phase 2 validated cross-node.
Phase 3 cannot be end-to-end validated on Azure CX-7 because dispatch/
combine still rely on legacy port_channel handle.put for the actual
token data payload, and that IB write path has the same cross-node
failure mode as signal/wait/putWithSignal on this RoCE config (see
debug history sec 10.6, 10.12). Phase 3 code is correct (verified:
overriding the NVLS ptrs with nullptr reproduces the same legacy hang)
and will activate cleanly on hardware with working IB.
This commit is contained in:
Qinghua Zhou
2026-05-09 19:25:29 +00:00
parent 8f2c4e7d98
commit 3ab2e43b79
5 changed files with 516 additions and 46 deletions

View File

@@ -330,15 +330,34 @@ void Buffer::sync(const std::vector<int>& device_ids,
// Allocate the RDMA buffer.
//
// Use mscclpp's `gpuCallocPhysical` (cuMemCreate + cuMemMap with the
// POSIX_FD|FABRIC handle types) instead of plain cudaMalloc. This makes
// the allocation eligible for cuMem fabric IPC, which lets the LL fast
// path map the buffer across the NVL72 fabric via nvidia-imex and
// perform atomicAdd over NVLink rather than RDMA. Cross-node HT (which
// still goes through PortChannel/IB) is unaffected — the IB MR
// registration in `registerMemory(..., all_transport)` below handles
// physical-allocator-backed pointers identically to cudaMalloc'd ones.
rdma_buffer_ptr = mscclpp::detail::gpuCallocPhysical(num_rdma_bytes);
// For low-latency mode on platforms that support NVLink-SHARP / NVLS
// (GB200 NVL72 with nvidia-imex configured), use mscclpp's
// `gpuCallocPhysical` (cuMemCreate + cuMemMap with POSIX_FD|FABRIC
// handle types) so the buffer is eligible for cuMem fabric IPC — the
// LL fast path then maps the buffer across the NVL72 fabric via
// nvidia-imex and performs atomicAdd over NVLink rather than RDMA
// (which has IBV_ATOMIC_NONE on Azure CX-7 RoCE).
//
// Fallback: on platforms without NVLS / multicast support (e.g.
// H100 + IB, A100 + IB), `gpuCallocPhysical` would either fail or
// produce non-fabric-IPC memory; fall back to plain `cudaMalloc` and
// let the LL path use the existing PortChannel proxy mechanism over
// IB. Same allocator is used for HT mode, which never needs fabric
// IPC since its kernels go through PortChannel RDMA WRITE on a
// standard IB-registered MR.
const bool use_fabric_ipc_alloc = low_latency_mode && mscclpp::isNvlsSupported();
if (use_fabric_ipc_alloc) {
rdma_buffer_ptr = mscclpp::detail::gpuCallocPhysical(num_rdma_bytes);
} else {
CUDA_CHECK(cudaMalloc(&rdma_buffer_ptr, num_rdma_bytes));
CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
}
if (rank == 0) {
printf("[mscclpp_ep] rdma_buffer allocator: %s (low_latency=%d, nvls=%d)\n",
use_fabric_ipc_alloc ? "gpuCallocPhysical (fabric-IPC)" : "cudaMalloc",
(int)low_latency_mode, (int)mscclpp::isNvlsSupported());
fflush(stdout);
}
bootstrap->barrier();
CUDA_CHECK(cudaDeviceSynchronize());
@@ -461,6 +480,104 @@ void Buffer::sync(const std::vector<int>& device_ids,
port_channel_handles.data(), port_channel_handles.size(),
cudaMemcpyHostToDevice);
// ------------------------------------------------------------------
// HT internode NVLS multicast setup (Wide Proposal B2).
//
// On platforms with NVLink-SHARP / multicast support (GB200 NVL72
// with nvidia-imex), allocate a multicast-bound buffer used by the HT
// dispatch / combine / notify_dispatch kernels for:
// - tail / head atomic counters (replaces the 4 atomicAdd sites)
// - notify_dispatch barrier epoch (replaces port_channel.signal/wait)
// - notify_dispatch small-data delivery (replaces putWithSignal)
//
// All cross-node atomic adds become `multimem.red.add.u64` PTX which
// travels over the NVL72 fabric — bypassing IB entirely. This is the
// same fabric path that LL Proposal A validated cross-node.
//
// Fallback (existing IB platforms): when `isNvlsSupported()` is false
// or there is only one RDMA rank (intranode-only), `nvls_ht_enabled`
// stays `false` and kernels select the legacy PortChannel path.
//
// Skipped for `low_latency_mode` since LL has its own (working)
// fabric-IPC path via Proposal A and does not use the HT counter
// protocol.
// ------------------------------------------------------------------
nvls_ht_enabled = false;
if (!low_latency_mode && num_rdma_ranks > 1 && mscclpp::isNvlsSupported()) {
// Worst-case sizing — chosen so the same multicast buffer fits any
// (num_sms, num_rdma_ranks) configuration the kernels may launch with.
const size_t kCounterBytesPerChannel =
static_cast<size_t>(NUM_MAX_RDMA_PEERS) * NUM_MAX_RDMA_PEERS * sizeof(uint64_t);
const size_t tail_bytes = static_cast<size_t>(kNvlsMaxChannels) * kCounterBytesPerChannel;
const size_t head_bytes = tail_bytes;
const size_t barrier_bytes = static_cast<size_t>(kNvlsBarrierSlots) * sizeof(uint64_t);
// Data region: per-sender slot of [num_rdma_ranks_max × per_peer_bytes],
// one slot per global rank (worst-case num_ranks = NUM_MAX_RDMA_PEERS *
// NUM_MAX_NVL_PEERS). Each rank writes its own slot via `multimem.st`;
// every receiver then reads the sub-position destined to it.
const size_t kPerSenderSlotBytes =
static_cast<size_t>(NUM_MAX_RDMA_PEERS) * kNvlsPerPeerBytes;
const size_t kMaxRanks =
static_cast<size_t>(NUM_MAX_RDMA_PEERS) * NUM_MAX_NVL_PEERS;
const size_t data_bytes = kMaxRanks * kPerSenderSlotBytes;
// 256 B alignment for each sub-region to keep `multimem` ops well-aligned.
auto align256 = [](size_t x) { return (x + 255) & ~size_t(255); };
nvls_ht_off_tail = 0;
nvls_ht_off_head = align256(nvls_ht_off_tail + tail_bytes);
nvls_ht_off_barrier = align256(nvls_ht_off_head + head_bytes);
nvls_ht_off_data = align256(nvls_ht_off_barrier + barrier_bytes);
nvls_ht_total_bytes = align256(nvls_ht_off_data + data_bytes);
// GpuBuffer auto-uses gpuCallocPhysicalShared (cuMem fabric handle)
// when isNvlsSupported() — required for multicast bind.
nvls_ht_buffer = std::make_shared<mscclpp::GpuBuffer<uint8_t>>(nvls_ht_total_bytes);
CUDA_CHECK(cudaMemset(nvls_ht_buffer->data(), 0, nvls_ht_buffer->bytes()));
std::vector<int> all_ranks;
all_ranks.reserve(num_ranks);
for (int r = 0; r < num_ranks; ++r) all_ranks.push_back(r);
// Collective: every rank must call. If it fails (e.g. IMEX
// misconfigured, peers in different fabrics), the exception
// propagates — there is no clean fallback mid-collective. The
// `isNvlsSupported()` gate above is the production guard.
nvls_ht_conn = mscclpp::connectNvlsCollective(communicator, all_ranks, nvls_ht_buffer->bytes());
auto sw = nvls_ht_conn->bindAllocatedMemory(
reinterpret_cast<CUdeviceptr>(nvls_ht_buffer->data()), nvls_ht_buffer->bytes());
nvls_ht_sc = std::make_shared<mscclpp::SwitchChannel>(std::move(sw));
auto h = nvls_ht_sc->deviceHandle();
nvls_ht_mc_ptr = h.mcPtr;
nvls_ht_dev_ptr = h.devicePtr;
nvls_ht_enabled = (nvls_ht_mc_ptr != nullptr) && (nvls_ht_dev_ptr != nullptr);
// DIAG: print mcPtr/devicePtr/buf-VA per rank to verify whether
// connectNvlsCollective produced a multicast that actually spans
// both nodes (suspected: per-node only on Azure GB200).
printf(
"[mscclpp_ep] NVLS HT diag rank=%d mcPtr=%p devicePtr=%p bufVA=%p bytes=%zu\n",
rank, (void*)nvls_ht_mc_ptr, (void*)nvls_ht_dev_ptr,
(void*)nvls_ht_buffer->data(), (size_t)nvls_ht_buffer->bytes());
fflush(stdout);
bootstrap->barrier();
if (rank == 0) {
printf(
"[mscclpp_ep] NVLS HT multicast: enabled=%d total=%zu KB "
"(tail@%zu head@%zu barrier@%zu data@%zu)\n",
(int)nvls_ht_enabled, nvls_ht_total_bytes / 1024, nvls_ht_off_tail, nvls_ht_off_head,
nvls_ht_off_barrier, nvls_ht_off_data);
fflush(stdout);
}
} else if (rank == 0) {
printf(
"[mscclpp_ep] NVLS HT multicast: disabled (low_latency=%d, num_rdma_ranks=%d, "
"nvls_supported=%d)\n",
(int)low_latency_mode, num_rdma_ranks, (int)mscclpp::isNvlsSupported());
fflush(stdout);
}
// ------------------------------------------------------------------
// Intra-node LL fast path setup.
//
@@ -1107,6 +1224,9 @@ Buffer::internode_dispatch(
// Send sizes
*moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
for (int i = 0; i < num_local_experts; ++i) moe_recv_expert_counter[i] = -1;
// NVLS Phase 2: bump the per-call epoch counter so the kernel's
// barrier spin uses a fresh expected value (epoch * num_ranks).
if (nvls_ht_enabled) ++nvls_ht_epoch;
internode::notify_dispatch(
num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr<int>(),
@@ -1116,7 +1236,10 @@ Buffer::internode_dispatch(
recv_gbl_rank_prefix_sum.data_ptr<int>(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode,
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
nvls_ht_enabled ? nvls_ht_mc_ptr : nullptr,
nvls_ht_enabled ? nvls_ht_dev_ptr : nullptr,
nvls_ht_off_barrier, nvls_ht_off_data, nvls_ht_epoch, kNvlsPerPeerBytes);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
@@ -1178,6 +1301,13 @@ Buffer::internode_dispatch(
}
// Launch data dispatch
// Phase 3: pass NVLS counter region pointers (head/tail × mc/dev). When
// `nvls_ht_enabled` is false, all four are nullptr and the kernel falls
// back to the legacy PortChannel/atomicAdd path.
void* nvls_head_mc = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_mc_ptr) + nvls_ht_off_head) : nullptr;
void* nvls_head_dev = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_dev_ptr) + nvls_ht_off_head) : nullptr;
void* nvls_tail_mc = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_mc_ptr) + nvls_ht_off_tail) : nullptr;
void* nvls_tail_dev = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_dev_ptr) + nvls_ht_off_tail) : nullptr;
internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr,
topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr<int>(),
@@ -1190,7 +1320,8 @@ Buffer::internode_dispatch(
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, cached_mode, comm_stream, num_channels, low_latency_mode,
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
nvls_head_mc, nvls_head_dev, nvls_tail_mc, nvls_tail_dev);
// Wait streams
std::optional<EventHandle> event;
@@ -1310,6 +1441,11 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
move_fifo_slots(2);
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
// Phase 3: NVLS counter region pointers for combine kernel.
void* combine_nvls_head_mc = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_mc_ptr) + nvls_ht_off_head) : nullptr;
void* combine_nvls_head_dev = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_dev_ptr) + nvls_ht_off_head) : nullptr;
void* combine_nvls_tail_mc = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_mc_ptr) + nvls_ht_off_tail) : nullptr;
void* combine_nvls_tail_dev = nvls_ht_enabled ? static_cast<void*>(static_cast<char*>(nvls_ht_dev_ptr) + nvls_ht_off_tail) : nullptr;
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(),
combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr<bool>(), x.data_ptr(),
topk_weights_ptr, combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
@@ -1318,7 +1454,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens,
config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens,
config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, comm_stream, num_channels,
low_latency_mode, port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get());
low_latency_mode, port_channel_handles_device_ptr.get(), memory_channel_handles_device_ptr.get(),
combine_nvls_head_mc, combine_nvls_head_dev, combine_nvls_tail_mc, combine_nvls_tail_dev);
std::optional<EventHandle> event;
if (async) {

View File

@@ -7,8 +7,10 @@
#include <torch/types.h>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/switch_channel.hpp>
#include <tuple>
#include <vector>
@@ -104,6 +106,55 @@ struct Buffer {
std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
bool ll_ipc_ready = false;
// NVLS multicast for HT internode (Wide Proposal B2).
//
// When `mscclpp::isNvlsSupported()` is true and `num_rdma_ranks > 1`,
// we set up a multicast-bound buffer carrying:
// - tail counters[kNvlsMaxChannels][NUM_MAX_RDMA_PEERS][NUM_MAX_RDMA_PEERS] uint64_t
// - head counters[kNvlsMaxChannels][NUM_MAX_RDMA_PEERS][NUM_MAX_RDMA_PEERS] uint64_t
// - notify_dispatch barrier slots[kNvlsBarrierSlots] uint64_t
// - notify_dispatch small-data slots: one per global rank, each
//   NUM_MAX_RDMA_PEERS * kNvlsPerPeerBytes bytes
//
// Cross-node atomic adds use `multimem.red.add.u64` PTX which travels
// over the NVL72 fabric instead of broken IB atomics on Azure CX-7 RoCE.
// The kernels select between this NVLS path and the legacy PortChannel
// path at runtime based on `nvls_ht_enabled`.
//
// Falls back gracefully on platforms without NVLS multicast support
// (e.g. H100+IB, A100+IB clusters): `nvls_ht_enabled` stays `false`,
// all NVLS pointers stay `nullptr`, and the original PortChannel
// signal/wait + atomicAdd path remains active.
bool nvls_ht_enabled = false;
std::shared_ptr<mscclpp::NvlsConnection> nvls_ht_conn;
// SwitchChannel keeps the multicast pointer alive (its destructor
// unbinds the multicast); device pointers below are extracted from it.
std::shared_ptr<mscclpp::SwitchChannel> nvls_ht_sc;
// Underlying GpuBuffer (multicast-eligible physical alloc); kept alive
// for the lifetime of the multicast binding.
std::shared_ptr<mscclpp::GpuBuffer<uint8_t>> nvls_ht_buffer;
// mc_ptr: multicast-side device pointer (writes hit all peers via switch).
// dev_ptr: local-side device pointer (reads see local copy of the same
// physical memory).
void* nvls_ht_mc_ptr = nullptr;
void* nvls_ht_dev_ptr = nullptr;
// Sub-region byte offsets within the multicast buffer (set in sync()).
size_t nvls_ht_off_tail = 0;
size_t nvls_ht_off_head = 0;
size_t nvls_ht_off_barrier = 0;
size_t nvls_ht_off_data = 0;
size_t nvls_ht_total_bytes = 0;
// Per-call epoch counter for NVLS barrier slots. Incremented on the host
// before each kernel launch that uses an NVLS barrier; the kernel spins
// until the barrier slot reaches `epoch * num_ranks`.
uint64_t nvls_ht_epoch = 0;
// Worst-case shape parameters used to size the buffer:
// stride_per_channel = NUM_MAX_RDMA_PEERS * NUM_MAX_RDMA_PEERS (counter slots)
// We allocate for `kNvlsMaxChannels` channels so any `num_sms` config fits.
static constexpr int kNvlsMaxChannels = 64; // num_sms / 2 upper bound
static constexpr int kNvlsPerPeerBytes = 1024; // small-data per (sender, receiver) pair
// Number of distinct barrier slots in the barrier sub-region (each u64).
static constexpr int kNvlsBarrierSlots = 8;
private:
void move_fifo_slots(int num_slots = 1);

View File

@@ -89,6 +89,10 @@ struct Config {
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
// Two extra uint64_t scratch slots per (channel, rdma_rank) used by the
// dispatch/combine kernels as the RDMA WRITE source for absolute-value
// tail/head updates (replaces broken HW atomicAdd on Azure CX-7 RoCE).
num_bytes += num_channels * num_rdma_ranks * sizeof(uint64_t) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}

View File

@@ -77,7 +77,10 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_mc_ptr, void* nvls_dev_ptr,
size_t nvls_off_barrier, size_t nvls_off_data,
uint64_t nvls_epoch, int nvls_per_peer_bytes);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
@@ -90,7 +93,8 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
bool is_cached_dispatch, cudaStream_t stream, int num_channels, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_ranks,
int num_channels, int num_combined_tokens, int* combined_rdma_head,
@@ -109,7 +113,8 @@ void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles);
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev);
} // namespace internode

View File

@@ -27,6 +27,45 @@ static_assert(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(NvlPackT),
namespace internode {
// ===========================================================================
// NVLS HT counter helpers (Phase 3).
//
// All cross-node tail/head counter updates that previously used HW
// atomicAdd (broken on Azure CX-7 RoCE: IBV_ATOMIC_NONE) and were
// patched to absolute-value RDMA WRITE — go through the NVLS multicast
// fabric instead. multimem.red.add is hardware-atomic across all ranks
// in the multicast group (validated cross-node Phase 2). For the legacy
// IB-platform path, callers pass `nvls_*_mc/dev == nullptr` and the
// existing PortChannel path runs.
//
// Layout (matches `nvls_off_tail/head` allocation in buffer.cc):
// region[kind][channel][src_rdma][dst_rdma] (uint64 each)
// Sender writes to slot(channel, my_rdma_rank, dst_rdma_rank).
// Receiver reads from slot(channel, src_rdma_rank, my_rdma_rank).
// Self-loop (src == dst == my_rdma_rank) goes through the same path.
// ===========================================================================
// Linearizes a (channel, source RDMA rank, destination RDMA rank) triple into
// a slot index inside one NVLS counter region.
//
// Every channel owns NUM_MAX_RDMA_PEERS x NUM_MAX_RDMA_PEERS slots laid out
// row-major as [src_rdma][dst_rdma]. The result is an element index, not a
// byte offset: callers apply it to a `uint64_t*`, so the 8-byte scaling is
// done by pointer arithmetic at the use site.
__device__ __forceinline__ size_t nvls_ctr_slot_index(int channel_id, int src_rdma, int dst_rdma) {
  const size_t peers = static_cast<size_t>(NUM_MAX_RDMA_PEERS);
  // Horner form of channel * peers^2 + src * peers + dst, evaluated in size_t.
  const size_t row = static_cast<size_t>(channel_id) * peers + static_cast<size_t>(src_rdma);
  return row * peers + static_cast<size_t>(dst_rdma);
}
// Adds `delta` to the counter slot for (channel_id, src_rdma, dst_rdma) in the
// region starting at `mc_base`.
//
// `mc_base` is expected to be the multicast-side base pointer of a counter
// region; the update is issued as a single `multimem.red` add with release
// semantics at system scope. The "memory" clobber stops the compiler from
// reordering surrounding memory accesses across the instruction.
__device__ __forceinline__ void nvls_ctr_add(void* mc_base, int channel_id, int src_rdma, int dst_rdma,
                                             uint64_t delta) {
  uint64_t* const counters = static_cast<uint64_t*>(mc_base);
  uint64_t* const target = counters + nvls_ctr_slot_index(channel_id, src_rdma, dst_rdma);
  asm volatile("multimem.red.release.sys.global.add.u64 [%0], %1;" ::"l"(target), "l"(delta) : "memory");
}
// Reads the counter slot for (channel_id, src_rdma, dst_rdma) from the region
// starting at `dev_base` and returns its current value.
//
// `dev_base` is expected to be the local (non-multicast) base pointer of a
// counter region. The read is an acquire load at system scope, so it pairs
// with the release adds issued by nvls_ctr_add.
__device__ __forceinline__ uint64_t nvls_ctr_load(void* dev_base, int channel_id, int src_rdma, int dst_rdma) {
  uint64_t* const counters = static_cast<uint64_t*>(dev_base);
  uint64_t* const source = counters + nvls_ctr_slot_index(channel_id, src_rdma, dst_rdma);
  uint64_t observed;
  asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(observed) : "l"(source) : "memory");
  return observed;
}
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void __launch_bounds__(kNumThreads, 1)
get_dispatch_layout(const int64_t* topk_idx, int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
@@ -217,7 +256,13 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, void** buffer_ptrs,
int** task_fifo_ptrs, int head, int rank,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
// NVLS Phase 2 — replaces port_channel signal/wait + putWithSignal.
// When `nvls_mc_ptr == nullptr` the legacy PortChannel path runs
// unchanged (fallback for non-NVLS IB platforms).
void* nvls_mc_ptr, void* nvls_dev_ptr,
size_t nvls_off_barrier, size_t nvls_off_data,
uint64_t nvls_epoch, int nvls_per_peer_bytes) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
@@ -235,7 +280,32 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
(barrier_thread_id >= 0) && (barrier_thread_id < kNumRDMARanks) && (barrier_thread_id != rdma_rank);
const auto barrier_channel_idx =
kLowLatencyMode ? barrier_thread_id : (barrier_thread_id * NUM_MAX_NVL_PEERS + nvl_rank);
if (run_barrier) {
if (nvls_mc_ptr != nullptr) {
// NVLS epoch barrier #0: every rank in the multicast group does
// `multimem.red.add 1` to slot[0] (over the NVL72 fabric — bypasses
// broken IB atomics on Azure CX-7 RoCE), then spins on its local
// mapping until the slot reaches `epoch * num_ranks` (everyone
// arrived). Replaces the pairwise PortChannel signal/wait above.
if (thread_id == 0) {
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b0\n", rank, (unsigned long long)nvls_epoch);
uint64_t* mc_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier);
uint64_t* dev_b0 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier);
uint64_t pre;
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(pre) : "l"(dev_b0) : "memory");
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b0) : "memory");
printf("[nvls] rank=%d enter b0 epoch=%llu pre=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)pre, (unsigned long long)(nvls_epoch * (uint64_t)num_ranks));
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
uint64_t v;
int spin = 0;
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b0) : "memory");
if ((++spin & 0xFFFFFFF) == 0) {
printf("[nvls] rank=%d spinning b0 v=%llu expected=%llu spin=%d\n", rank, (unsigned long long)v, (unsigned long long)expected, spin);
}
} while (v < expected);
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b0 v=%llu expected=%llu\n", rank, (unsigned long long)nvls_epoch, (unsigned long long)v, (unsigned long long)expected);
}
} else if (run_barrier) {
port_channel_handles[barrier_channel_idx].signal();
port_channel_handles[barrier_channel_idx].wait();
}
@@ -284,7 +354,60 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
// Issue send
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) {
if (nvls_mc_ptr != nullptr) {
// NVLS Phase 2 data path: every rank broadcasts its full send blob
// (kNumRDMARanks × num_bytes contiguous in send_buffer(0..)) into
// its own slot in the multicast data region using `multimem.st.u32`.
// After a second NVLS epoch barrier, every receiver copies the
// sub-block destined to it into the legacy `recv_buffer(s)` location
// so the rest of the kernel reads the same layout as before.
EP_DEVICE_ASSERT(num_bytes <= static_cast<size_t>(nvls_per_peer_bytes));
const int my_global_rank = kLowLatencyMode ? rdma_rank : (rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
const size_t slot_stride_bytes =
static_cast<size_t>(NUM_MAX_RDMA_PEERS) * static_cast<size_t>(nvls_per_peer_bytes);
// Sender: write our blob (num_elems × kNumRDMARanks ints) to our slot.
// SymBuffer<int>::send_buffer(i) lives at base + i*num_elems ints in
// the local send region — contiguous from send_buffer(0).
int* src_ints = rdma_recv_num_tokens_mixed.send_buffer(0);
const int total_ints = num_elems * kNumRDMARanks;
int* mc_slot = reinterpret_cast<int*>(
static_cast<char*>(nvls_mc_ptr) + nvls_off_data +
static_cast<size_t>(my_global_rank) * slot_stride_bytes);
for (int i = thread_id; i < total_ints; i += num_threads) {
int val = src_ints[i];
asm volatile("multimem.st.relaxed.sys.global.u32 [%0], %1;" ::"l"(mc_slot + i), "r"(val) : "memory");
}
__syncthreads();
// NVLS epoch barrier #1: ensure every sender's multimem.st has been
// delivered to all peers before we read.
if (thread_id == 0) {
if (rank == 0) printf("[nvls] rank=%d epoch=%llu enter b1\n", rank, (unsigned long long)nvls_epoch);
uint64_t* mc_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_mc_ptr) + nvls_off_barrier + 8);
uint64_t* dev_b1 = reinterpret_cast<uint64_t*>(static_cast<char*>(nvls_dev_ptr) + nvls_off_barrier + 8);
asm volatile("multimem.red.release.sys.global.add.u64 [%0], 1;" ::"l"(mc_b1) : "memory");
const uint64_t expected = nvls_epoch * static_cast<uint64_t>(num_ranks);
uint64_t v;
do {
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(v) : "l"(dev_b1) : "memory");
} while (v < expected);
if (rank == 0) printf("[nvls] rank=%d epoch=%llu pass b1\n", rank, (unsigned long long)nvls_epoch);
}
__syncthreads();
// Receiver: for each sender s, copy num_elems ints from the NVLS slot
// (sub-block at offset rdma_rank * num_bytes within the slot) into
// the legacy SymBuffer recv_buffer(s) location.
for (int s = 0; s < kNumRDMARanks; ++s) {
const int sender_global = kLowLatencyMode ? s : (s * NUM_MAX_NVL_PEERS + nvl_rank);
const int* nvls_src = reinterpret_cast<const int*>(
static_cast<char*>(nvls_dev_ptr) + nvls_off_data +
static_cast<size_t>(sender_global) * slot_stride_bytes +
static_cast<size_t>(rdma_rank) * num_bytes);
int* dst = rdma_recv_num_tokens_mixed.recv_buffer(s);
for (int i = thread_id; i < num_elems; i += num_threads) {
dst[i] = nvls_src[i];
}
}
} else if (thread_id < kNumRDMARanks) {
auto dst_offset = rdma_rank * num_bytes + per_channel_bytes;
auto src_offset = thread_id * num_bytes;
auto peer_rank = kLowLatencyMode ? thread_id : (thread_id * NUM_MAX_NVL_PEERS + nvl_rank);
@@ -451,7 +574,10 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int num_max_nvl_chunked_recv_tokens, int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_mc_ptr, void* nvls_dev_ptr,
size_t nvls_off_barrier, size_t nvls_off_data,
uint64_t nvls_epoch, int nvls_per_peer_bytes) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto notify_dispatch_func = \
@@ -462,7 +588,8 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second, nvl_clean_meta.first, \
nvl_clean_meta.second, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, rdma_buffer_ptr, buffer_ptrs, task_fifo_ptrs, \
head, rank, port_channel_handles, memory_channel_handles); \
head, rank, port_channel_handles, memory_channel_handles, \
nvls_mc_ptr, nvls_dev_ptr, nvls_off_barrier, nvls_off_data, nvls_epoch, nvls_per_peer_bytes); \
} \
break
@@ -501,7 +628,9 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens,
int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
// Phase 3 NVLS counter pointers (nullptr → fall back to PortChannel/atomicAdd path).
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev) {
enum class WarpRole {
kRDMASender,
kRDMASenderCoordinator,
@@ -539,6 +668,9 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
EP_DEVICE_ASSERT(num_topk <= 32);
// RDMA symmetric layout (packed-bool size guard is at namespace scope via NvlPackT).
// Snapshot the original base before SymBuffer constructors advance it; used
// below to compute MR-relative offsets for handle.put().
void* const rdma_buffer_ptr_base = rdma_buffer_ptr;
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data =
@@ -548,6 +680,16 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// Scratch slots (one uint64_t per (channel, peer) per kind) holding the
// absolute counter values used as RDMA WRITE source. Replaces the broken
// HW atomicAdd path on Azure CX-7 RoCE (IBV_ATOMIC_NONE) — each tail/head
// slot has a single writer per peer, so atomicity is unnecessary; an
// absolute-value RDMA WRITE through the same QP as the data PUT preserves
// ordering by IB semantics.
auto rdma_channel_tail_send_src =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head_send_src =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto data_send_offset =
sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id;
@@ -687,8 +829,17 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
int rdma_tail_idx = -1;
if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
// Phase 3: NVLS fast path. lane_id is dst_rdma_rank (peer consumer).
// Slot is keyed (producer=rdma_rank, consumer=lane_id).
if (nvls_head_dev != nullptr) {
cached_rdma_channel_head =
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, lane_id));
} else {
cached_rdma_channel_head =
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
}
}
__syncwarp();
@@ -807,7 +958,36 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= num_tokens_to_send);
if (num_tokens_to_issue == 0) continue;
if (dst_rdma_rank == rdma_rank) {
if (nvls_tail_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed by (producer=rdma_rank,
// consumer=dst_rdma_rank). Self-loop and cross-node both go here.
// Note: for cross-node, the data PUT below still uses port_channel.
// The tail no longer travels on the IB queue, so IB FIFO ordering
// between handle.put(data) and a follow-up counter write does not
// apply here; the consumer reads the NVLS counter (nvls_ctr_load)
// instead. The data must still land before the consumer observes the
// new tail, so the cross-node branch calls handle.flush() right after
// the put and only then issues the NVLS counter add.
// For the self-loop case there is no IB at all (the data path is
// local), so no put/flush is issued and the counter add alone suffices.
if (dst_rdma_rank != rdma_rank) {
const auto dst_slot_idx = last_issued_tail % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
const auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) +
dst_slot_idx * num_bytes_per_rdma_token + data_recv_offset;
const auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) +
dst_slot_idx * num_bytes_per_rdma_token + data_send_offset;
const auto port_channel_idx = kLowLatencyMode
? (channel_id * kNumRDMARanks + dst_rdma_rank)
: (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
auto& handle = port_channel_handles[port_channel_idx];
handle.put(dst_offset, src_offset, num_bytes_per_msg);
// Force IB completion before signaling via NVLS so the data
// lands before the consumer observes the new tail.
handle.flush();
}
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank,
(uint64_t)num_tokens_to_issue);
} else if (dst_rdma_rank == rdma_rank) {
// Update tails
mscclpp::atomicFetchAdd(reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
(uint64_t)num_tokens_to_issue, mscclpp::memoryOrderRelease);
@@ -824,9 +1004,18 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
auto& handle = port_channel_handles[port_channel_idx];
handle.put(dst_offset, src_offset, num_bytes_per_msg);
// Remote atomic add on the peer's tail counter: +num_tokens_to_issue.
handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_tokens_to_issue);
// handle.flush();
// HW atomicAdd is broken on Azure CX-7 RoCE (IBV_ATOMIC_NONE).
// Write the new absolute tail to a local scratch slot, then RDMA
// WRITE it to the peer's tail recv slot. Single-writer-per-slot
// makes atomicity unnecessary; queuing on the same channel as the
// data PUT above keeps IB ordering (data lands before counter).
const uint64_t new_tail = (uint64_t)last_issued_tail + (uint64_t)num_tokens_to_issue;
*rdma_channel_tail_send_src.buffer(dst_rdma_rank) = new_tail;
__threadfence_system();
const auto src_off_tail =
reinterpret_cast<uintptr_t>(rdma_channel_tail_send_src.buffer(dst_rdma_rank)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
handle.put(rdma_rank * sizeof(uint64_t) + tail_send_offset, src_off_tail, sizeof(uint64_t));
}
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
@@ -915,8 +1104,17 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
while (true) {
src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=src_rdma_rank,
// consumer=rdma_rank). Same coordinate as writer-side `nvls_ctr_add`.
if (nvls_tail_dev != nullptr) {
uint64_t v = nvls_ctr_load(nvls_tail_dev, channel_id, src_rdma_rank, rdma_rank);
cached_rdma_channel_tail = static_cast<int>(v);
} else {
cached_rdma_channel_tail =
static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
}
}
if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) break;
}
@@ -1024,7 +1222,14 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and
lane_id < kNumRDMARanks) {
if (lane_id == rdma_rank) {
if (nvls_head_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
// consumer=rdma_rank). This is the same slot the producer polls with
// `nvls_ctr_load(nvls_head_dev, channel, rdma_rank, lane_id)`: on that
// rank, rdma_rank is the producer and lane_id is this consumer, so the
// (P, C) pair is canonical regardless of who reads/writes.
nvls_ctr_add(nvls_head_mc, channel_id, lane_id, rdma_rank,
(uint64_t)(min_head - last_head));
} else if (lane_id == rdma_rank) {
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
(uint64_t)(min_head - last_head), mscclpp::memoryOrderRelease);
} else {
@@ -1032,8 +1237,13 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
auto port_channel_idx = kLowLatencyMode ? (channel_id * kNumRDMARanks + lane_id)
: (channel_id * num_ranks + lane_id * NUM_MAX_NVL_PEERS + nvl_rank);
auto& handle = port_channel_handles[port_channel_idx];
// Remote atomic add on the peer's head counter.
handle.atomicAdd(dst_offset, (int64_t)(min_head - last_head));
// Absolute-value RDMA WRITE replaces broken HW atomicAdd (see note above).
*rdma_channel_head_send_src.buffer(lane_id) = (uint64_t)min_head;
__threadfence_system();
const auto src_off_head =
reinterpret_cast<uintptr_t>(rdma_channel_head_send_src.buffer(lane_id)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
handle.put(dst_offset, src_off_head, sizeof(uint64_t));
}
last_head = min_head;
}
@@ -1147,7 +1357,8 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
bool is_cached_dispatch, cudaStream_t stream, int num_channels, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev) {
constexpr int kNumDispatchRDMASenderWarps = 7;
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
@@ -1164,7 +1375,8 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, num_tokens, hidden_int4, num_scales, num_topk, \
num_experts, is_token_in_rank, rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, \
num_max_rdma_chunked_recv_tokens, buffer_ptrs, num_max_nvl_chunked_send_tokens, \
num_max_nvl_chunked_recv_tokens, rank, num_ranks, port_channel_handles, memory_channel_handles); \
num_max_nvl_chunked_recv_tokens, rank, num_ranks, port_channel_handles, memory_channel_handles, \
nvls_head_mc, nvls_head_dev, nvls_tail_mc, nvls_tail_dev); \
} \
break
@@ -1395,7 +1607,9 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
// Phase 3 NVLS counter pointers (nullptr → fall back to PortChannel/atomicAdd path).
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev) {
enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator };
const auto sm_id = static_cast<int>(blockIdx.x);
@@ -1548,7 +1762,8 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
}
} else {
// Combiners and coordinators
// RDMA symmetric layout
// RDMA symmetric layout (snapshot the base before SymBuffer advances it).
void* const rdma_buffer_ptr_base = rdma_buffer_ptr;
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data =
@@ -1556,6 +1771,12 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// Scratch slots for absolute-value RDMA WRITE replacement of broken HW
// atomicAdd on Azure CX-7 RoCE (see dispatch kernel for full rationale).
auto rdma_channel_tail_send_src =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head_send_src =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto data_send_offset =
sizeof(int8_t) * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) * kNumRDMARanks * channel_id;
@@ -1644,7 +1865,15 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
while (sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
// Phase 3: NVLS counter fast path. Slot keyed (producer=rdma_rank,
// consumer=dst_rdma_rank). I'm the producer here.
int num_used_slots;
if (nvls_head_dev != nullptr) {
num_used_slots = token_start_idx -
static_cast<int>(nvls_ctr_load(nvls_head_dev, channel_id, rdma_rank, dst_rdma_rank));
} else {
num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
}
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) break;
// Timeout check
@@ -1708,7 +1937,26 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
// Issue RDMA send
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (lane_id == 0) {
if (dst_rdma_rank == rdma_rank) {
if (nvls_tail_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=rdma_rank,
// consumer=dst_rdma_rank). Self-loop and cross-node both go here.
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
auto dst_offset = rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) +
rdma_slot_idx * num_bytes_per_rdma_token + data_recv_offset;
auto src_offset = dst_rdma_rank * (num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token) +
rdma_slot_idx * num_bytes_per_rdma_token + data_send_offset;
auto port_channel_idx = kLowLatencyMode
? (channel_id * kNumRDMARanks + dst_rdma_rank)
: (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
auto& handle = port_channel_handles[port_channel_idx];
handle.put(dst_offset, src_offset, num_bytes_per_msg);
handle.flush();
}
nvls_ctr_add(nvls_tail_mc, channel_id, rdma_rank, dst_rdma_rank,
(uint64_t)num_chunked_tokens);
} else if (dst_rdma_rank == rdma_rank) {
mscclpp::atomicFetchAdd(reinterpret_cast<uint64_t*>(rdma_channel_tail.buffer(rdma_rank)),
(uint64_t)num_chunked_tokens, mscclpp::memoryOrderRelease);
} else {
@@ -1724,8 +1972,14 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
auto& handle = port_channel_handles[port_channel_idx];
handle.put(dst_offset, src_offset, num_bytes_per_msg);
// Remote atomic add on the peer's tail counter: +num_chunked_tokens.
handle.atomicAdd(rdma_rank * sizeof(uint64_t) + tail_send_offset, (int64_t)num_chunked_tokens);
// Absolute-value RDMA WRITE replaces broken HW atomicAdd.
const uint64_t new_tail = (uint64_t)(token_start_idx + num_chunked_tokens);
*rdma_channel_tail_send_src.buffer(dst_rdma_rank) = new_tail;
__threadfence_system();
const auto src_off_tail =
reinterpret_cast<uintptr_t>(rdma_channel_tail_send_src.buffer(dst_rdma_rank)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
handle.put(rdma_rank * sizeof(uint64_t) + tail_send_offset, src_off_tail, sizeof(uint64_t));
}
}
__syncwarp();
@@ -1762,7 +2016,14 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
// Wait lanes to be ready
auto start_time = clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Phase 3: NVLS counter fast path. Slot keyed (producer=lane_id,
// consumer=rdma_rank). I'm the consumer here, peer is producer.
if (nvls_tail_dev != nullptr) {
cached_channel_tail_idx = static_cast<int>(
nvls_ctr_load(nvls_tail_dev, channel_id, lane_id, rdma_rank));
} else {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
}
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
@@ -1822,7 +2083,12 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and
min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
if (dst_rdma_rank == rdma_rank) {
if (nvls_head_mc != nullptr) {
// Phase 3: NVLS counter fast path. Slot keyed (producer=dst_rdma_rank,
// consumer=rdma_rank). I'm consuming, peer is producer.
nvls_ctr_add(nvls_head_mc, channel_id, dst_rdma_rank, rdma_rank,
(uint64_t)(min_head - last_rdma_head));
} else if (dst_rdma_rank == rdma_rank) {
mscclpp::atomicFetchAdd(static_cast<uint64_t*>(rdma_channel_head.buffer(rdma_rank)),
(uint64_t)(min_head - last_rdma_head), mscclpp::memoryOrderRelease);
} else {
@@ -1831,8 +2097,13 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
? (channel_id * kNumRDMARanks + dst_rdma_rank)
: (channel_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
auto& handle = port_channel_handles[port_channel_idx];
// Remote atomic add on the peer's head counter.
handle.atomicAdd(dst_offset, (int64_t)(min_head - last_rdma_head));
// Absolute-value RDMA WRITE replaces broken HW atomicAdd.
*rdma_channel_head_send_src.buffer(dst_rdma_rank) = (uint64_t)min_head;
__threadfence_system();
const auto src_off_head =
reinterpret_cast<uintptr_t>(rdma_channel_head_send_src.buffer(dst_rdma_rank)) -
reinterpret_cast<uintptr_t>(rdma_buffer_ptr_base);
handle.put(dst_offset, src_off_head, sizeof(uint64_t));
}
last_rdma_head = min_head;
}
@@ -1866,7 +2137,8 @@ void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
void* nvls_head_mc, void* nvls_head_dev, void* nvls_tail_mc, void* nvls_tail_dev) {
constexpr int kNumCombineForwarderWarps = 16;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
@@ -1879,7 +2151,8 @@ void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights
rdma_rank_prefix_sum, gbl_channel_prefix_matrix, num_tokens, num_combined_tokens, hidden, num_topk, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks, \
port_channel_handles, memory_channel_handles); \
port_channel_handles, memory_channel_handles, \
nvls_head_mc, nvls_head_dev, nvls_tail_mc, nvls_tail_dev); \
} \
break