ext/ep: fix LL IB atomicAdd alignment by widening signaling buffers to int64

The low-latency dispatch/combine kernels signal recv counts via MSCCL++
PortChannel.atomicAdd, which lowers to IB IBV_WR_ATOMIC_FETCH_AND_ADD.
That opcode requires the remote address to be 8-byte aligned, but
LowLatencyLayout packed the per-expert signaling slots as int32, so
odd-indexed slots landed at offset % 8 == 4; the NIC silently dropped
those atomics and the target rank spun forever in recv_hook (observed:
the even->odd direction works while odd->even does not, across all
tested topologies: 2-rank intra-node, 8-rank intra-node, and 2-node
with one GPU each).
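
For illustration only (a host-side sketch, not part of the change), the
slot-address arithmetic behind the failure, assuming the signaling region
base is 8-byte aligned:

// Sketch: why int32 slots break IBV_WR_ATOMIC_FETCH_AND_ADD (which needs
// remote_addr % 8 == 0) while int64_t slots are always safe.
#include <cstdint>
#include <cstdio>

int main() {
    const std::uintptr_t base = 0x2000;  // assume 8-byte-aligned region start
    for (int slot = 0; slot < 4; ++slot) {
        std::uintptr_t old_addr = base + slot * sizeof(int32_t);  // old layout
        std::uintptr_t new_addr = base + slot * sizeof(int64_t);  // widened layout
        // Old layout: odd slots land at % 8 == 4; new layout: always % 8 == 0.
        std::printf("slot %d: old %%8=%d, new %%8=%d\n", slot,
                    static_cast<int>(old_addr % 8),
                    static_cast<int>(new_addr % 8));
    }
    return 0;
}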

Widen dispatch_rdma_recv_count_buffer / combine_rdma_recv_flag_buffer to
int64_t, update clean kernel + kernel signatures + next_clean pointers
accordingly, and add int64_t overloads for st_na_release /
ld_acquire_sys_global in utils.cuh.

Also drop the bogus self CUDA-IPC connection in Buffer::sync() that was
previously skewing the cross-rank buildAndAddSemaphore handshake order;
the kernel's same-rank branch uses a direct warp copy and never touches
the self port-channel slot (filled with a zero-initialized placeholder
so the [local_expert*num_ranks + dst_rank] indexing still holds).
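
A condensed paraphrase of the signaling branch this relies on (assumed
helper name signal_count; relies on the repo's utils.cuh and MSCCL++
headers, and is not the actual kernel code):

// Sketch: cross-rank counts go through the port channel's IB atomicAdd,
// same-rank counts are release-stored locally, so the placeholder handle
// at [expert_local_idx * num_ranks + rank] is indexed but never used.
__device__ __forceinline__ void signal_count(
        mscclpp::PortChannelDeviceHandle* handles, int64_t* rdma_recv_count,
        int expert_local_idx, int dst_rank, int rank, int num_ranks,
        uint64_t off, int64_t encoded_count) {
    if (dst_rank != rank) {
        handles[expert_local_idx * num_ranks + dst_rank].atomicAdd(off, encoded_count);
    } else {
        st_na_release(rdma_recv_count + expert_local_idx * num_ranks + rank,
                      encoded_count);
    }
}
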
Qinghua Zhou
2026-04-23 06:31:18 +00:00
parent c51a8a5305
commit 1e430874ce
5 changed files with 64 additions and 38 deletions

View File

@@ -312,14 +312,16 @@ void Buffer::sync(const std::vector<int> &device_ids,
memory_ids[r] = proxy_service->addMemory(std::move(mem));
}
// Rank -> vector of connections
// Rank -> vector of connections. Skip self: the kernel's same-rank
// path uses a direct warp copy (see internode_ll.cu `dst_rank != rank`
// check) and never dereferences the self-slot port channel. Creating
// a self CUDA-IPC connection + self semaphore previously skewed the
// cross-rank `buildAndAddSemaphore` handshake sequence between ranks,
// leading to asymmetric semaphore pairings that prevented atomicAdd
// signals from being delivered in one direction during LL dispatch.
std::unordered_map<int, std::vector<mscclpp::Connection>> connections;
const mscclpp::EndpointConfig ipc_cfg(ipc_transport);
const mscclpp::EndpointConfig ib_cfg(ib_transport);
// Self connection for local memory (CUDA IPC).
connections[rank].emplace_back(communicator->connect(ipc_cfg, rank, kRdmaTag).get());
// Remote IB connections (multi-QP per peer).
const int num_ib_connections_per_rank = 12; // #QPs per rank (mirrors DeepEP).
for (int r = 0; r < num_ranks; ++r) {
@@ -333,11 +335,14 @@ void Buffer::sync(const std::vector<int> &device_ids,
}
// Rank -> vector of semaphore IDs. Iterate peers in sorted rank order so
// semaphore pairings between nodes line up deterministically.
// semaphore pairings between nodes line up deterministically. Self is
// skipped so both sides see an identical sequence of cross-rank
// `buildAndAddSemaphore` calls.
std::unordered_map<int, std::vector<mscclpp::SemaphoreId>> sema_ids;
const int num_semaphores_per_rank = 16;
for (int i = 0; i < num_semaphores_per_rank; ++i) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) continue;
auto conn_it = connections.find(r);
EP_HOST_ASSERT(conn_it != connections.end());
auto& conns = conn_it->second;
@@ -351,12 +356,17 @@ void Buffer::sync(const std::vector<int> &device_ids,
//
// The kernels index `port_channel_handles[channel_id * num_ranks + peer_rank]`
// where peer_rank is a GLOBAL rank in [0..num_ranks). So the outer stride must
// be num_ranks with peers in ascending rank order. Iterating `memory_ids` (an
// `unordered_map`) yields hash order and would misroute signals, deadlocking.
// be num_ranks with peers in ascending rank order. The self slot is filled
// with a zero-initialized placeholder handle that the kernels never touch.
const int num_port_channels_per_rank = num_semaphores_per_rank;
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handles;
for (int i = 0; i < num_port_channels_per_rank; ++i) {
for (int r = 0; r < num_ranks; ++r) {
if (r == rank) {
// Placeholder; indexed but never dispatched by kernels.
port_channel_handles.emplace_back(mscclpp::PortChannelDeviceHandle{});
continue;
}
auto mem_it = memory_ids.find(r);
EP_HOST_ASSERT(mem_it != memory_ids.end());
auto memory_id = mem_it->second;

View File

@@ -96,16 +96,20 @@ struct LowLatencyBuffer {
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
// NOTE: signaling buffers are int64_t (not int) so that IB atomic ops
// (IBV_WR_ATOMIC_FETCH_AND_ADD is a 64-bit, 8-byte-aligned op) always
// target an 8-byte-aligned address. Using int32 slots produced unaligned
// atomics at odd indices that the NIC silently drops.
int64_t* dispatch_rdma_recv_count_buffer = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
int64_t* combine_rdma_recv_flag_buffer = nullptr;
void* combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
std::pair<int*, int> clean_meta() {
std::pair<int64_t*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
@@ -149,8 +153,8 @@ struct LowLatencyLayout {
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
// Symmetric signaling buffers (int64_t slots for 8-byte-aligned IB atomics).
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
@@ -160,13 +164,13 @@ struct LowLatencyLayout {
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * i),
num_bytes_per_combine_msg
};
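
A hypothetical sanity check that could sit right after this loop (not in
the diff): with int64_t slots every per-expert signaling address stays
8-byte aligned, which is what the IB atomic requires.

// Hypothetical post-construction assertion inside LowLatencyLayout:
for (int i = 0; i < 2; ++i) {
    auto* counts = buffers[i].dispatch_rdma_recv_count_buffer;
    for (int e = 0; e < num_experts; ++e)
        EP_HOST_ASSERT(reinterpret_cast<uintptr_t>(counts + e) % 8 == 0);
}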

View File

@@ -135,8 +135,8 @@ void combine(cudaDataType_t type,
// ===========================================================================
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
cudaStream_t stream);
@@ -144,9 +144,9 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases,
@@ -154,10 +154,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
mscclpp::PortChannelDeviceHandle* port_channel_handles);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream,

View File

@@ -71,8 +71,8 @@ __device__ __forceinline__ void port_channel_barrier_block(
// ---------------------------------------------------------------------------
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
int rank, int num_ranks) {
// Barrier before cleaning (in case of unfinished chunked EP)
@@ -91,8 +91,8 @@ __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
port_channel_barrier_block(port_channel_handles, rank, num_ranks);
}
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
mscclpp::PortChannelDeviceHandle* port_channel_handles,
cudaStream_t stream) {
@@ -113,10 +113,10 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int phases,
@@ -284,7 +284,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
static_cast<int64_t>(-num_tokens_sent - 1));
}
atomic_counter_per_expert[responsible_expert_idx] = 0;
@@ -320,8 +321,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_recv_tokens, recv_token_begin_idx;
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
if (sub_warp_id == 1 and lane_id == 0) {
while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
num_recv_tokens = -num_recv_tokens - 1;
int64_t raw;
while ((raw = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
num_recv_tokens = static_cast<int>(-raw - 1);
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
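
The value convention in the two hunks above: a stored 0 means "no signal
yet", so the sender transmits -(count + 1) and the receiver decodes it
after the spin. A tiny sketch with hypothetical helper names:

// Hypothetical encode/decode helpers for the recv-count signal: 0 is the
// "not yet arrived" sentinel, so a real count n (even n == 0) becomes -(n + 1).
__host__ __device__ constexpr int64_t encode_count(int n) {
    return -static_cast<int64_t>(n) - 1;
}
__host__ __device__ constexpr int decode_count(int64_t raw) {
    return static_cast<int>(-raw - 1);
}
static_assert(decode_count(encode_count(0)) == 0,
              "a zero count must still be distinguishable from 'not arrived'");
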
@@ -358,9 +360,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases,
@@ -406,10 +408,10 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
@@ -496,7 +498,7 @@ combine(void* combined_x,
port_channel_handles[local_expert_idx * num_ranks + dst_rank]
.atomicAdd(off, static_cast<int64_t>(1));
} else {
st_na_release(rdma_recv_flag + global_expert_idx, 1);
st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
}
atomic_add_release_global(atomic_clean_flag, -1);
}
@@ -550,10 +552,10 @@ combine(void* combined_x,
}
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream,

View File

@@ -70,6 +70,12 @@ __device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
return ret;
}
__device__ __forceinline__ int64_t ld_acquire_sys_global(const int64_t *ptr) {
int64_t ret;
asm volatile("ld.acquire.sys.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
int ret;
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
@@ -232,6 +238,10 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
}
__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
}
// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define ST_NA_FUNC "st.global.L1::no_allocate"
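
A minimal, hypothetical smoke-test kernel (not part of this commit) showing
how the new overloads get picked when the pointer is int64_t*:

// Hypothetical test: an int64_t* argument selects the .b64 release store
// and the .s64 acquire load added above; a real receiver would decode a
// token count from the loaded value.
__global__ void ll_flag_smoke_test(int64_t* flag) {
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        st_na_release(flag, static_cast<int64_t>(1));   // writer side
        int64_t v;
        while ((v = ld_acquire_sys_global(flag)) == 0); // reader spin
        (void)v;                                        // v == 1 here
    }
}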