diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc
index 98746abb..6946f63e 100644
--- a/src/ext/ep/buffer.cc
+++ b/src/ext/ep/buffer.cc
@@ -118,6 +118,18 @@ Buffer::~Buffer() noexcept(false) {
         // failed, so there is nothing to tear down.
     }
 
+    // Intra-node LL IPC fast-path teardown.
+    if (ll_ipc_ready) {
+        for (int i = 0; i < num_ranks; ++i) {
+            if (i == rank or peer_rdma_bases[i] == nullptr) continue;
+            CUDA_CHECK(cudaIpcCloseMemHandle(peer_rdma_bases[i]));
+        }
+        if (peer_rdma_bases_gpu != nullptr) {
+            CUDA_CHECK(cudaFree(peer_rdma_bases_gpu));
+            peer_rdma_bases_gpu = nullptr;
+        }
+    }
+
     proxy_service->stopProxy();
 
     // Free cuBLAS handle, workspace and MoE counter
@@ -372,6 +384,80 @@ void Buffer::sync(const std::vector<int>& device_ids,
     mscclpp::gpuMemcpy(
         port_channel_handles_device_ptr.get(), port_channel_handles.data(),
         port_channel_handles.size(), cudaMemcpyHostToDevice);
+
+    // ------------------------------------------------------------------
+    // Intra-node LL fast-path setup.
+    //
+    // When all ranks sit on the same host (num_rdma_ranks == 1), LL dispatch
+    // and combine still go through `PortChannel` above, which internally
+    // uses the proxy service over IB loopback between different HCAs on
+    // this platform. That path is correct but slow (caps at ~170 GB/s vs.
+    // NVLink's multi-TB/s). We additionally set up CUDA-IPC peer pointers
+    // to each peer's `rdma_buffer_ptr` plus a set of per-peer MemoryChannels
+    // for a barrier ring. The LL kernels select this path at launch time.
+    // Cross-node LL is unaffected: this block is a no-op there.
+    // ------------------------------------------------------------------
+    if (low_latency_mode and num_rdma_ranks == 1) {
+        EP_HOST_ASSERT(num_ranks == num_nvl_ranks);
+        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS);
+
+        // 1. Exchange CUDA IPC handles for rdma_buffer_ptr via bootstrap.
+        CUDA_CHECK(cudaIpcGetMemHandle(&rdma_ipc_handles[rank], rdma_buffer_ptr));
+        std::vector<cudaIpcMemHandle_t> all_rdma_handles(num_ranks);
+        all_rdma_handles[rank] = rdma_ipc_handles[rank];
+        bootstrap->allGather(all_rdma_handles.data(), sizeof(cudaIpcMemHandle_t));
+
+        peer_rdma_bases[rank] = rdma_buffer_ptr;
+        for (int r = 0; r < num_ranks; ++r) {
+            if (r == rank) continue;
+            rdma_ipc_handles[r] = all_rdma_handles[r];
+            CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r],
+                                            cudaIpcMemLazyEnablePeerAccess));
+        }
+        CUDA_CHECK(cudaMalloc(&peer_rdma_bases_gpu, sizeof(void*) * NUM_MAX_NVL_PEERS));
+        CUDA_CHECK(cudaMemcpy(peer_rdma_bases_gpu, peer_rdma_bases,
+                              sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
+
+        // 2. Build MemoryChannels for the per-peer barrier ring. These use
+        //    CUDA IPC connections (distinct tag from the existing port-channel
+        //    machinery) so setup does not interfere with cross-node fallback.
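+        //
+        //    A sketch of the tag contract, assuming mscclpp's tag-matched
+        //    setup semantics: Communicator pairs sendMemory/recvMemory and
+        //    connect calls by (remote rank, tag), so both sides must issue
+        //    the same per-peer calls on kLlIpcTag, and kLlIpcTag must differ
+        //    from every tag the port-channel setup above already uses, or
+        //    these futures could resolve against the wrong exchange. E.g.
+        //    with num_ranks == 4 and rank == 1, the loop below posts one
+        //    sendMemory/recvMemory pair and one connect on kLlIpcTag toward
+        //    each of ranks 0, 2, and 3.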
+        constexpr int kLlIpcTag = 2;
+        auto rdma_mem_ipc = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, ipc_transport);
+        std::vector<std::shared_future<mscclpp::RegisteredMemory>> remote_futures(num_ranks);
+        for (int r = 0; r < num_ranks; ++r) {
+            if (r == rank) continue;
+            communicator->sendMemory(rdma_mem_ipc, r, kLlIpcTag);
+            remote_futures[r] = communicator->recvMemory(r, kLlIpcTag);
+        }
+        std::vector<std::shared_ptr<mscclpp::Connection>> ll_ipc_conns(num_ranks);
+        {
+            std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> conn_futures(num_ranks);
+            mscclpp::EndpointConfig cfg(ipc_transport);
+            for (int r = 0; r < num_ranks; ++r) {
+                if (r == rank) continue;
+                conn_futures[r] = communicator->connect(cfg, r, kLlIpcTag);
+            }
+            for (int r = 0; r < num_ranks; ++r) {
+                if (r == rank) continue;
+                ll_ipc_conns[r] = conn_futures[r].get();
+            }
+        }
+
+        std::vector<mscclpp::MemoryChannelDeviceHandle> ll_handles(num_ranks);
+        for (int r = 0; r < num_ranks; ++r) {
+            if (r == rank) continue;
+            auto sema = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, ll_ipc_conns[r]);
+            ll_memory_channels.emplace_back(sema, remote_futures[r].get(), rdma_mem_ipc);
+            ll_handles[r] = ll_memory_channels.back().deviceHandle();
+        }
+        ll_memory_channel_handles_device_ptr =
+            mscclpp::detail::gpuCallocShared<mscclpp::MemoryChannelDeviceHandle>(num_ranks);
+        mscclpp::gpuMemcpy(
+            ll_memory_channel_handles_device_ptr.get(), ll_handles.data(), num_ranks,
+            cudaMemcpyHostToDevice);
+
+        ll_ipc_ready = true;
+    }
 }
 
 // Ready to use
@@ -1175,6 +1261,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
                                      clean_meta_1.first, clean_meta_1.second,
                                      rank, num_ranks,
                                      port_channel_handles_device_ptr.get(),
+                                     ll_memory_channel_handles_device_ptr ?
+                                         ll_memory_channel_handles_device_ptr.get() : nullptr,
+                                     ll_ipc_ready,
                                      at::cuda::getCurrentCUDAStream());
 }
 
@@ -1223,6 +1312,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto next_clean_meta = next_buffer.clean_meta();
     auto port_handles = port_channel_handles_device_ptr.get();
+    auto mem_handles = ll_memory_channel_handles_device_ptr ?
+        ll_memory_channel_handles_device_ptr.get() : nullptr;
+    auto peer_bases = peer_rdma_bases_gpu;
+    const bool use_ipc = ll_ipc_ready;
     auto rdma_base = rdma_buffer_ptr;
     auto launcher = [=](int phases) {
         internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
@@ -1235,7 +1328,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, workspace, launch_stream, phases,
-                               rdma_base, port_handles);
+                               rdma_base, port_handles,
+                               peer_bases, mem_handles, use_ipc);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1303,6 +1397,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     auto next_clean_meta = next_buffer.clean_meta();
     auto port_handles = port_channel_handles_device_ptr.get();
+    auto mem_handles = ll_memory_channel_handles_device_ptr ?
+        ll_memory_channel_handles_device_ptr.get() : nullptr;
+    auto peer_bases = peer_rdma_bases_gpu;
+    const bool use_ipc = ll_ipc_ready;
     auto rdma_base = rdma_buffer_ptr;
     auto launcher = [=](int phases) {
         internode_ll::combine(combined_x.data_ptr(),
@@ -1315,7 +1413,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                               num_topk, num_experts, rank, num_ranks,
                               workspace, launch_stream, phases, zero_copy,
-                              rdma_base, port_handles);
+                              rdma_base, port_handles,
+                              peer_bases, mem_handles, use_ipc);
     };
     launcher(return_recv_hook ?
                 LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp
index 1b0399ef..625c4074 100644
--- a/src/ext/ep/buffer.hpp
+++ b/src/ext/ep/buffer.hpp
@@ -82,6 +82,18 @@ private:
     std::shared_ptr<mscclpp::PortChannelDeviceHandle> port_channel_handles_device_ptr;
     std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> memory_channel_handles_device_ptr;
 
+    // Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC).
+    // ``peer_rdma_bases[r]`` aliases rank ``r``'s ``rdma_buffer_ptr`` via
+    // ``cudaIpcOpenMemHandle`` (lazy peer access). Populated in ``sync()`` when
+    // ``low_latency_mode && num_rdma_ranks == 1``; null otherwise.
+    cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS];
+    void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr};
+    void** peer_rdma_bases_gpu = nullptr;
+    // MemoryChannels over CUDA IPC used only for the LL barrier ring.
+    std::vector<mscclpp::MemoryChannel> ll_memory_channels;
+    std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
+    bool ll_ipc_ready = false;
+
 private:
     void move_fifo_slots(int num_slots = 1);
diff --git a/src/ext/ep/kernels/api.cuh b/src/ext/ep/kernels/api.cuh
index 2b54e8e6..7647cb97 100644
--- a/src/ext/ep/kernels/api.cuh
+++ b/src/ext/ep/kernels/api.cuh
@@ -139,6 +139,8 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                               int64_t* clean_1, int num_clean_int_1,
                               int rank, int num_ranks,
                               mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                              bool use_ipc_path,
                               cudaStream_t stream);
 
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
@@ -151,7 +153,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, void* workspace, cudaStream_t stream,
               int phases, void* rdma_buffer_ptr,
-              mscclpp::PortChannelDeviceHandle* port_channel_handles);
+              mscclpp::PortChannelDeviceHandle* port_channel_handles,
+              void* const* peer_rdma_bases,
+              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+              bool use_ipc_path);
 
 void combine(void* combined_x,
              void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -163,7 +168,10 @@ void combine(void* combined_x,
              void* workspace, cudaStream_t stream,
             int phases, bool zero_copy,
              void* rdma_buffer_ptr,
-             mscclpp::PortChannelDeviceHandle* port_channel_handles);
+             mscclpp::PortChannelDeviceHandle* port_channel_handles,
+             void* const* peer_rdma_bases,
+             mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+             bool use_ipc_path);
 
 } // namespace internode_ll
diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu
index 0a915aab..e4ffffb7 100644
--- a/src/ext/ep/kernels/internode_ll.cu
+++ b/src/ext/ep/kernels/internode_ll.cu
@@ -33,6 +33,7 @@
 #include <cooperative_groups.h>
 #include <mscclpp/port_channel_device.hpp>
+#include <mscclpp/memory_channel_device.hpp>
 
 namespace cg = cooperative_groups;
 
@@ -52,6 +53,15 @@ __device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buffer_ptr) {
     return ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
 }
 
+// Translate a local pointer aliased into our own rdma_buffer_ptr into the
+// peer-mapped pointer for rank `peer_rank`. Only valid when the IPC fast path
+// is active (`peer_rdma_bases != nullptr`).
+__device__ __forceinline__ uint64_t peer_ptr_of(uint64_t local_ptr, void* const* peer_rdma_bases,
+                                                void* rdma_buffer_ptr, int peer_rank) {
+    const auto off = local_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
+    return reinterpret_cast<uint64_t>(peer_rdma_bases[peer_rank]) + off;
+}
+
 // Cross-rank barrier via port-channel signal/wait ring.
 // Uses port channel `qp=0` across all connected peers.
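+// A sketch of what both barrier variants below implement: every thread with
+// tid < num_ranks and tid != rank drives the channel to peer tid, issuing a
+// signal() and then wait()-ing for the peer's matching signal, so no thread
+// passes until all peers have arrived; the trailing __syncthreads() then
+// publishes arrival to the whole block. With num_ranks == 8, for instance,
+// threads 0-7 (excluding our own rank) each own exactly one channel.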
 __device__ __forceinline__ void port_channel_barrier_block(
@@ -66,17 +76,43 @@ __device__ __forceinline__ void port_channel_barrier_block(
     __syncthreads();
 }
 
+// Same barrier but using MemoryChannels (CUDA IPC, NVLink). Used on the
+// intra-node LL fast path.
+__device__ __forceinline__ void memory_channel_barrier_block(
+    mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+    int rank, int num_ranks) {
+    const int tid = threadIdx.x;
+    if (tid < num_ranks && tid != rank) {
+        memory_channel_handles[tid].signal();
+        memory_channel_handles[tid].wait();
+    }
+    __syncthreads();
+}
+
+template <bool kIpcPath>
+__device__ __forceinline__ void ll_barrier_block(
+    mscclpp::PortChannelDeviceHandle* port_channel_handles,
+    mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+    int rank, int num_ranks) {
+    if constexpr (kIpcPath) {
+        memory_channel_barrier_block(memory_channel_handles, rank, num_ranks);
+    } else {
+        port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+    }
+}
+
 // ---------------------------------------------------------------------------
 // clean_low_latency_buffer
 // ---------------------------------------------------------------------------
 
-template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
+template <int kNumThreads, bool kIpcPath> __launch_bounds__(kNumThreads, 1)
 __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                          int64_t* clean_1, int num_clean_int_1,
                                          mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                                         mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
                                          int rank, int num_ranks) {
     // Barrier before cleaning (in case of unfinished chunked EP)
-    port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+    ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);
 
     // Clean
     auto thread_id = static_cast<int>(threadIdx.x);
@@ -88,27 +124,35 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
         clean_1[i] = 0;
 
     // Barrier after cleaning (make sure low-latency mode work fine)
-    port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+    ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);
 }
 
 void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                               int64_t* clean_1, int num_clean_int_1,
                               int rank, int num_ranks,
                               mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                              bool use_ipc_path,
                               cudaStream_t stream) {
     constexpr int kNumThreads = 256;
 
     SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
-    LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
-                  clean_0, num_clean_int_0, clean_1, num_clean_int_1,
-                  port_channel_handles, rank, num_ranks);
+    if (use_ipc_path) {
+        LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, true>),
+                      clean_0, num_clean_int_0, clean_1, num_clean_int_1,
+                      port_channel_handles, memory_channel_handles, rank, num_ranks);
+    } else {
+        LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, false>),
+                      clean_0, num_clean_int_0, clean_1, num_clean_int_1,
+                      port_channel_handles, memory_channel_handles, rank, num_ranks);
+    }
 }
 
 // ---------------------------------------------------------------------------
 // dispatch
 // ---------------------------------------------------------------------------
 
-template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, bool kIpcPath>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -121,7 +165,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int num_topk, int num_experts, int rank, int num_ranks,
          int phases, void* rdma_buffer_ptr,
-         mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+         mscclpp::PortChannelDeviceHandle* port_channel_handles,
+         void* const* peer_rdma_bases,
+         mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -214,14 +260,22 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                              rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                              slot_idx * num_bytes_per_msg;
         if (dst_rank != rank) {
-            // MSCCL++ port-channel PUT (lane 0 issues one request).
-            if (lane_id == 0) {
-                const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
-                const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
-                port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
-                    .put(dst_off, src_off, num_bytes_per_msg);
+            if constexpr (kIpcPath) {
+                // Peer-mapped warp copy over NVLink (CUDA IPC).
+                const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
+                const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                const auto* dst_int4_ptr = reinterpret_cast<const int4*>(peer_dst);
+                UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+            } else {
+                // MSCCL++ port-channel PUT (lane 0 issues one request).
+                if (lane_id == 0) {
+                    const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
+                    const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
+                    port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
+                        .put(dst_off, src_off, num_bytes_per_msg);
+                }
+                __syncwarp();
             }
-            __syncwarp();
         } else {
             const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
             const auto* dst_int4_ptr = reinterpret_cast<const int4*>(dst_ptr);
@@ -279,10 +333,19 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
 
         if (dst_rank != rank) {
-            auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
-            const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
-            port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
-                .atomicAdd(off, static_cast<int>(-num_tokens_sent - 1));
+            if constexpr (kIpcPath) {
+                // Single writer per (dst_expert_local_idx, rank) slot, so a
+                // release-ordered store on the peer-mapped counter is enough.
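+                // (Assumed pairing: the receiver polls this slot with an
+                // acquire load, so release/acquire over NVLink orders the
+                // token copies above before the count becomes visible. The
+                // value keeps the usual encoding, -num_tokens_sent - 1, so
+                // even zero tokens yields a nonzero flag (-1) that a cleared
+                // slot (0) can never alias. The IB path keeps the proxy
+                // atomicAdd instead, which the proxy queues behind the puts
+                // already issued on the same channel.)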
+                auto peer_counter = reinterpret_cast<int*>(peer_ptr_of(
+                    reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank),
+                    peer_rdma_bases, rdma_buffer_ptr, dst_rank));
+                st_na_release(peer_counter, static_cast<int>(-num_tokens_sent - 1));
+            } else {
+                auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
+                const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
+                port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
+                    .atomicAdd(off, static_cast<int>(-num_tokens_sent - 1));
+            }
         } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
                           static_cast<int>(-num_tokens_sent - 1));
@@ -367,7 +430,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, void* workspace, cudaStream_t stream,
               int phases, void* rdma_buffer_ptr,
-              mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+              mscclpp::PortChannelDeviceHandle* port_channel_handles,
+              void* const* peer_rdma_bases,
+              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+              bool use_ipc_path) {
     constexpr int kNumMaxTopK = 9;
     constexpr int kNumWarpsPerGroup = 10;
     constexpr int kNumWarpGroups = 3;
@@ -382,19 +448,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
 
 #define DISPATCH_LAUNCH_CASE(hidden_case) { \
-auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> : \
-                               dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
-LAUNCH_KERNEL(&cfg, dispatch_func, \
-              packed_recv_x, packed_recv_x_scales, \
-              packed_recv_src_info, packed_recv_layout_range, \
-              packed_recv_count, \
-              rdma_recv_x, rdma_recv_count, rdma_x, \
-              x, topk_idx, \
-              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
-              next_clean, num_next_clean_int, \
-              num_tokens, num_max_dispatch_tokens_per_rank, \
-              num_topk, num_experts, rank, num_ranks, phases, \
-              rdma_buffer_ptr, port_channel_handles); } break
+if (use_ipc_path) { \
+    auto dispatch_func = use_fp8 ? \
+        dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, true> \
+        : dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, true>; \
+    LAUNCH_KERNEL(&cfg, dispatch_func, \
+                  packed_recv_x, packed_recv_x_scales, \
+                  packed_recv_src_info, packed_recv_layout_range, \
+                  packed_recv_count, \
+                  rdma_recv_x, rdma_recv_count, rdma_x, \
+                  x, topk_idx, \
+                  atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+                  next_clean, num_next_clean_int, \
+                  num_tokens, num_max_dispatch_tokens_per_rank, \
+                  num_topk, num_experts, rank, num_ranks, phases, \
+                  rdma_buffer_ptr, port_channel_handles, \
+                  peer_rdma_bases, memory_channel_handles); \
+} else { \
+    auto dispatch_func = use_fp8 ? \
+        dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, false> \
+        : dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, false>; \
+    LAUNCH_KERNEL(&cfg, dispatch_func, \
+                  packed_recv_x, packed_recv_x_scales, \
+                  packed_recv_src_info, packed_recv_layout_range, \
+                  packed_recv_count, \
+                  rdma_recv_x, rdma_recv_count, rdma_x, \
+                  x, topk_idx, \
+                  atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+                  next_clean, num_next_clean_int, \
+                  num_tokens, num_max_dispatch_tokens_per_rank, \
+                  num_topk, num_experts, rank, num_ranks, phases, \
+                  rdma_buffer_ptr, port_channel_handles, \
+                  peer_rdma_bases, memory_channel_handles); \
+} } break
 
 SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
 SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -405,7 +489,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
 // combine
 // ---------------------------------------------------------------------------
 
-template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
+template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk, bool kIpcPath>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 combine(void* combined_x,
         void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -418,7 +502,9 @@ combine(void* combined_x,
         int num_experts, int rank, int num_ranks,
         int phases, bool zero_copy,
         void* rdma_buffer_ptr,
-        mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+        mscclpp::PortChannelDeviceHandle* port_channel_handles,
+        void* const* peer_rdma_bases,
+        mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
@@ -474,17 +560,25 @@ combine(void* combined_x,
             const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
             UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
         } else {
-            const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
-            if (not zero_copy)
-                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
-            // MSCCL++ port-channel PUT.
-            if (lane_id == 0) {
-                const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
-                const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
-                port_channel_handles[local_expert_idx * num_ranks + dst_rank]
-                    .put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
+            if constexpr (kIpcPath) {
+                // Peer-mapped warp copy over NVLink. `zero_copy` is irrelevant
+                // on this path because we skip the rdma_send staging buffer.
+                const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
+                const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
+                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
+            } else {
+                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
+                if (not zero_copy)
+                    UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
+                // MSCCL++ port-channel PUT.
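+                // (Unlike the IPC branch above, where all 32 lanes move the
+                // payload, the proxy path issues one descriptor per token:
+                // lane 0 alone enqueues the put of the staged buf_ptr
+                // payload, and the __syncwarp() below reconverges the warp.
+                // With zero_copy the caller is assumed to have written
+                // rdma_send_x already, so the staging copy above is skipped
+                // but the put itself is identical.)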
+                if (lane_id == 0) {
+                    const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
+                    const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
+                    port_channel_handles[local_expert_idx * num_ranks + dst_rank]
+                        .put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
+                }
+                __syncwarp();
             }
-            __syncwarp();
         }
     }
 
@@ -493,10 +587,17 @@ combine(void* combined_x,
     if (sub_warp_id == 1 and lane_id == 0) {
         while (ld_acquire_global(atomic_clean_flag) == 0);
         if (dst_rank != rank) {
-            auto* flag_ptr = rdma_recv_flag + global_expert_idx;
-            const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
-            port_channel_handles[local_expert_idx * num_ranks + dst_rank]
-                .atomicAdd(off, static_cast<int64_t>(1));
+            if constexpr (kIpcPath) {
+                auto peer_flag = reinterpret_cast<int64_t*>(peer_ptr_of(
+                    reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx),
+                    peer_rdma_bases, rdma_buffer_ptr, dst_rank));
+                st_na_release(peer_flag, static_cast<int64_t>(1));
+            } else {
+                auto* flag_ptr = rdma_recv_flag + global_expert_idx;
+                const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
+                port_channel_handles[local_expert_idx * num_ranks + dst_rank]
+                    .atomicAdd(off, static_cast<int64_t>(1));
+            }
         } else {
             st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
         }
@@ -561,7 +662,10 @@ void combine(void* combined_x,
              void* workspace, cudaStream_t stream,
              int phases, bool zero_copy,
              void* rdma_buffer_ptr,
-             mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+             mscclpp::PortChannelDeviceHandle* port_channel_handles,
+             void* const* peer_rdma_bases,
+             mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+             bool use_ipc_path) {
     constexpr int kNumWarpsPerGroup = 10;
     constexpr int kNumWarpGroups = 3;
     constexpr int kNumMaxTopk = 9;
@@ -574,18 +678,35 @@ void combine(void* combined_x,
     EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
 
 #define COMBINE_LAUNCH_CASE(hidden_case) { \
-auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
-LAUNCH_KERNEL(&cfg, combine_func, \
-              combined_x, \
-              rdma_recv_x, rdma_recv_flag, rdma_send_x, \
-              x, topk_idx, topk_weights, src_info, layout_range, \
-              next_clean, num_next_clean_int, \
-              atomic_clean_flag, \
-              num_combined_tokens, hidden, num_topk, \
-              num_max_dispatch_tokens_per_rank, \
-              num_experts, rank, num_ranks, \
-              phases, zero_copy, \
-              rdma_buffer_ptr, port_channel_handles); } break
+if (use_ipc_path) { \
+    auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk, true>; \
+    LAUNCH_KERNEL(&cfg, combine_func, \
+                  combined_x, \
+                  rdma_recv_x, rdma_recv_flag, rdma_send_x, \
+                  x, topk_idx, topk_weights, src_info, layout_range, \
+                  next_clean, num_next_clean_int, \
+                  atomic_clean_flag, \
+                  num_combined_tokens, hidden, num_topk, \
+                  num_max_dispatch_tokens_per_rank, \
+                  num_experts, rank, num_ranks, \
+                  phases, zero_copy, \
+                  rdma_buffer_ptr, port_channel_handles, \
+                  peer_rdma_bases, memory_channel_handles); \
+} else { \
+    auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk, false>; \
+    LAUNCH_KERNEL(&cfg, combine_func, \
+                  combined_x, \
+                  rdma_recv_x, rdma_recv_flag, rdma_send_x, \
+                  x, topk_idx, topk_weights, src_info, layout_range, \
+                  next_clean, num_next_clean_int, \
+                  atomic_clean_flag, \
+                  num_combined_tokens, hidden, num_topk, \
+                  num_max_dispatch_tokens_per_rank, \
+                  num_experts, rank, num_ranks, \
+                  phases, zero_copy, \
+                  rdma_buffer_ptr, port_channel_handles, \
+                  peer_rdma_bases, memory_channel_handles); \
+} } break
 
 SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
 SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);