ext/ep: LL intra-node fast path via CUDA IPC + MemoryChannel
When all ranks live on the same host (num_rdma_ranks == 1), the LL
kernels now bypass PortChannel/IB-loopback entirely. In Buffer::sync()
we additionally:
- allGather IPC handles for each rank's rdma_buffer_ptr and
cudaIpcOpenMemHandle them into peer_rdma_bases[]
- build per-peer MemoryChannels over CUDA IPC connections (tag=2)
used only for the LL barrier ring
The three LL kernels (clean / dispatch / combine) gain a kIpcPath
template parameter and two extra args (peer_rdma_bases,
memory_channel_handles). At each peer op:
- put -> peer-mapped warp copy over NVLink
- atomicAdd-like flag store -> single-writer st_na_release on peer ptr
- signal/wait barrier -> MemoryChannel signal/wait
Cross-node LL (num_rdma_ranks > 1) is untouched; the IPC setup block is
a no-op. The host launch wrappers select the variant via use_ipc_path.
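
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the CUDA IPC handle exchange that the fast path performs in Buffer::sync() and unwinds in the destructor. It is illustrative only: MPI_Allgather stands in for bootstrap->allGather(), and the buffer size, macro, and naming are made up.

// Minimal sketch of the CUDA IPC exchange used by the fast path.
// Assumptions: one GPU per rank, all ranks on one host, MPI for bootstrap.
#include <cuda_runtime.h>
#include <mpi.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define CUDA_CHECK(call)                                      \
  do {                                                        \
    cudaError_t err_ = (call);                                \
    if (err_ != cudaSuccess) {                                \
      std::fprintf(stderr, "%s\n", cudaGetErrorString(err_)); \
      std::abort();                                           \
    }                                                         \
  } while (0)

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank, num_ranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);
  CUDA_CHECK(cudaSetDevice(rank));

  void* buffer = nullptr;
  CUDA_CHECK(cudaMalloc(&buffer, 1 << 20));

  // 1. Every rank exports a handle for its allocation and all-gathers them.
  std::vector<cudaIpcMemHandle_t> handles(num_ranks);
  CUDA_CHECK(cudaIpcGetMemHandle(&handles[rank], buffer));
  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
                handles.data(), sizeof(cudaIpcMemHandle_t), MPI_BYTE,
                MPI_COMM_WORLD);

  // 2. Open every peer's handle; the returned pointer aliases the peer's
  //    device memory inside this process's address space.
  std::vector<void*> peer_bases(num_ranks, nullptr);
  peer_bases[rank] = buffer;
  for (int r = 0; r < num_ranks; ++r) {
    if (r == rank) continue;
    CUDA_CHECK(cudaIpcOpenMemHandle(&peer_bases[r], handles[r],
                                    cudaIpcMemLazyEnablePeerAccess));
  }

  // 3. Teardown mirrors the destructor hunk below: close peer mappings
  //    before freeing our own allocation.
  for (int r = 0; r < num_ranks; ++r)
    if (r != rank && peer_bases[r] != nullptr)
      CUDA_CHECK(cudaIpcCloseMemHandle(peer_bases[r]));
  CUDA_CHECK(cudaFree(buffer));
  MPI_Finalize();
  return 0;
}

The commit's real setup differs in that handles travel over the MSCCL++ bootstrap and the peer pointer table is additionally mirrored to device memory (peer_rdma_bases_gpu) so kernels can index it.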
@@ -118,6 +118,18 @@ Buffer::~Buffer() noexcept(false) {
     // failed, so there is nothing to tear down.
   }

+  // Intra-node LL IPC fast-path teardown.
+  if (ll_ipc_ready) {
+    for (int i = 0; i < num_ranks; ++i) {
+      if (i == rank or peer_rdma_bases[i] == nullptr) continue;
+      CUDA_CHECK(cudaIpcCloseMemHandle(peer_rdma_bases[i]));
+    }
+    if (peer_rdma_bases_gpu != nullptr) {
+      CUDA_CHECK(cudaFree(peer_rdma_bases_gpu));
+      peer_rdma_bases_gpu = nullptr;
+    }
+  }
+
   proxy_service->stopProxy();

   // Free cuBLAS handle, workspace and MoE counter
@@ -372,6 +384,80 @@ void Buffer::sync(const std::vector<int> &device_ids,
   mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
       port_channel_handles_device_ptr.get(), port_channel_handles.data(), port_channel_handles.size(),
       cudaMemcpyHostToDevice);

+  // ------------------------------------------------------------------
+  // Intra-node LL fast path setup.
+  //
+  // When all ranks sit on the same host (num_rdma_ranks == 1), LL dispatch
+  // and combine still go through `PortChannel` above — which internally
+  // uses the proxy service over IB loopback between different HCAs on
+  // this platform. That path is correct but slow (caps at ~170 GB/s vs.
+  // NVLink's multi-TB/s). We additionally set up CUDA-IPC peer pointers
+  // to each peer's `rdma_buffer_ptr` plus a set of per-peer MemoryChannels
+  // for a barrier ring. The LL kernels select this path at launch time.
+  // Cross-node LL is unaffected: this block is a no-op there.
+  // ------------------------------------------------------------------
+  if (low_latency_mode and num_rdma_ranks == 1) {
+    EP_HOST_ASSERT(num_ranks == num_nvl_ranks);
+    EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS);
+
+    // 1. Exchange CUDA IPC handles for rdma_buffer_ptr via bootstrap.
+    CUDA_CHECK(cudaIpcGetMemHandle(&rdma_ipc_handles[rank], rdma_buffer_ptr));
+    std::vector<cudaIpcMemHandle_t> all_rdma_handles(num_ranks);
+    all_rdma_handles[rank] = rdma_ipc_handles[rank];
+    bootstrap->allGather(all_rdma_handles.data(), sizeof(cudaIpcMemHandle_t));
+
+    peer_rdma_bases[rank] = rdma_buffer_ptr;
+    for (int r = 0; r < num_ranks; ++r) {
+      if (r == rank) continue;
+      rdma_ipc_handles[r] = all_rdma_handles[r];
+      CUDA_CHECK(cudaIpcOpenMemHandle(&peer_rdma_bases[r], rdma_ipc_handles[r],
+                                      cudaIpcMemLazyEnablePeerAccess));
+    }
+    CUDA_CHECK(cudaMalloc(&peer_rdma_bases_gpu, sizeof(void*) * NUM_MAX_NVL_PEERS));
+    CUDA_CHECK(cudaMemcpy(peer_rdma_bases_gpu, peer_rdma_bases,
+                          sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
+
+    // 2. Build MemoryChannels for the per-peer barrier ring. These use
+    //    CUDA IPC connections (distinct tag from the existing port-channel
+    //    machinery) so setup does not interfere with cross-node fallback.
+    constexpr int kLlIpcTag = 2;
+    auto rdma_mem_ipc = communicator->registerMemory(rdma_buffer_ptr, num_rdma_bytes, ipc_transport);
+    std::vector<std::shared_future<mscclpp::RegisteredMemory>> remote_futures(num_ranks);
+    for (int r = 0; r < num_ranks; ++r) {
+      if (r == rank) continue;
+      communicator->sendMemory(rdma_mem_ipc, r, kLlIpcTag);
+      remote_futures[r] = communicator->recvMemory(r, kLlIpcTag);
+    }
+    std::vector<mscclpp::Connection> ll_ipc_conns(num_ranks);
+    {
+      std::vector<std::shared_future<mscclpp::Connection>> conn_futures(num_ranks);
+      mscclpp::EndpointConfig cfg(ipc_transport);
+      for (int r = 0; r < num_ranks; ++r) {
+        if (r == rank) continue;
+        conn_futures[r] = communicator->connect(cfg, r, kLlIpcTag);
+      }
+      for (int r = 0; r < num_ranks; ++r) {
+        if (r == rank) continue;
+        ll_ipc_conns[r] = conn_futures[r].get();
+      }
+    }
+
+    std::vector<mscclpp::MemoryChannelDeviceHandle> ll_handles(num_ranks);
+    for (int r = 0; r < num_ranks; ++r) {
+      if (r == rank) continue;
+      auto sema = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, ll_ipc_conns[r]);
+      ll_memory_channels.emplace_back(sema, remote_futures[r].get(), rdma_mem_ipc);
+      ll_handles[r] = ll_memory_channels.rbegin()->deviceHandle();
+    }
+    ll_memory_channel_handles_device_ptr =
+        mscclpp::detail::gpuCallocShared<mscclpp::MemoryChannelDeviceHandle>(num_ranks);
+    mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
+        ll_memory_channel_handles_device_ptr.get(), ll_handles.data(), num_ranks,
+        cudaMemcpyHostToDevice);
+
+    ll_ipc_ready = true;
+  }
 }

 // Ready to use
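
One caveat worth noting: cudaIpcMemLazyEnablePeerAccess defers peer-access enablement to first use, so a mapping can open successfully and still fault later if two local GPUs cannot actually reach each other. A hypothetical pre-flight check, not part of this patch, could gate ll_ipc_ready:

// Hypothetical pre-flight check (not in this patch): verify every pair of
// local devices supports peer access before relying on the IPC fast path.
bool all_pairs_peer_accessible(const std::vector<int>& device_ids) {
  for (int a : device_ids) {
    for (int b : device_ids) {
      if (a == b) continue;
      int ok = 0;
      CUDA_CHECK(cudaDeviceCanAccessPeer(&ok, a, b));
      if (!ok) return false;
    }
  }
  return true;
}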
@@ -1175,6 +1261,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
       clean_meta_1.first, clean_meta_1.second,
       rank, num_ranks,
       port_channel_handles_device_ptr.get(),
+      ll_memory_channel_handles_device_ptr ?
+          ll_memory_channel_handles_device_ptr.get() : nullptr,
+      ll_ipc_ready,
       at::cuda::getCurrentCUDAStream());
 }

@@ -1223,6 +1312,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i

   auto next_clean_meta = next_buffer.clean_meta();
   auto port_handles = port_channel_handles_device_ptr.get();
+  auto mem_handles = ll_memory_channel_handles_device_ptr ?
+      ll_memory_channel_handles_device_ptr.get() : nullptr;
+  auto peer_bases = peer_rdma_bases_gpu;
+  const bool use_ipc = ll_ipc_ready;
   auto rdma_base = rdma_buffer_ptr;
   auto launcher = [=](int phases) {
     internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
@@ -1235,7 +1328,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         num_tokens, hidden, num_max_dispatch_tokens_per_rank,
         num_topk, num_experts, rank, num_ranks, use_fp8,
         workspace, launch_stream, phases,
-        rdma_base, port_handles);
+        rdma_base, port_handles,
+        peer_bases, mem_handles, use_ipc);
   };
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

@@ -1303,6 +1397,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id

   auto next_clean_meta = next_buffer.clean_meta();
   auto port_handles = port_channel_handles_device_ptr.get();
+  auto mem_handles = ll_memory_channel_handles_device_ptr ?
+      ll_memory_channel_handles_device_ptr.get() : nullptr;
+  auto peer_bases = peer_rdma_bases_gpu;
+  const bool use_ipc = ll_ipc_ready;
   auto rdma_base = rdma_buffer_ptr;
   auto launcher = [=](int phases) {
     internode_ll::combine(combined_x.data_ptr(),
@@ -1315,7 +1413,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
         num_topk, num_experts, rank, num_ranks,
         workspace, launch_stream,
         phases, zero_copy,
-        rdma_base, port_handles);
+        rdma_base, port_handles,
+        peer_bases, mem_handles, use_ipc);
   };
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

@@ -82,6 +82,18 @@ private:
   std::shared_ptr<mscclpp::PortChannelDeviceHandle> port_channel_handles_device_ptr;
   std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> memory_channel_handles_device_ptr;

+  // Intra-node LL only: peer-mapped RDMA buffer pointers (CUDA IPC).
+  // ``peer_rdma_bases[r]`` aliases rank ``r``'s ``rdma_buffer_ptr`` via
+  // ``cudaIpcOpenMemHandle`` (lazy peer access). Populated in ``sync()`` when
+  // ``low_latency_mode && num_rdma_ranks == 1``; null otherwise.
+  cudaIpcMemHandle_t rdma_ipc_handles[NUM_MAX_NVL_PEERS];
+  void* peer_rdma_bases[NUM_MAX_NVL_PEERS] = {nullptr};
+  void** peer_rdma_bases_gpu = nullptr;
+  // MemoryChannels over CUDA IPC used only for the LL barrier ring.
+  std::vector<mscclpp::MemoryChannel> ll_memory_channels;
+  std::shared_ptr<mscclpp::MemoryChannelDeviceHandle> ll_memory_channel_handles_device_ptr;
+  bool ll_ipc_ready = false;
+
 private:
   void move_fifo_slots(int num_slots = 1);

@@ -139,6 +139,8 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                               int64_t* clean_1, int num_clean_int_1,
                               int rank, int num_ranks,
                               mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                              bool use_ipc_path,
                               cudaStream_t stream);

 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
@@ -151,7 +153,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases,
               void* rdma_buffer_ptr,
-              mscclpp::PortChannelDeviceHandle* port_channel_handles);
+              mscclpp::PortChannelDeviceHandle* port_channel_handles,
+              void* const* peer_rdma_bases,
+              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+              bool use_ipc_path);

 void combine(void* combined_x,
              void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -163,7 +168,10 @@ void combine(void* combined_x,
             void* workspace, cudaStream_t stream,
             int phases, bool zero_copy,
             void* rdma_buffer_ptr,
-            mscclpp::PortChannelDeviceHandle* port_channel_handles);
+            mscclpp::PortChannelDeviceHandle* port_channel_handles,
+            void* const* peer_rdma_bases,
+            mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+            bool use_ipc_path);

 } // namespace internode_ll

@@ -33,6 +33,7 @@

 #include <cooperative_groups.h>
 #include <mscclpp/port_channel_device.hpp>
+#include <mscclpp/memory_channel_device.hpp>

 namespace cg = cooperative_groups;

@@ -52,6 +53,15 @@ __device__ __forceinline__ uint64_t rdma_offset_of(uint64_t ptr, void* rdma_buff
   return ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
 }

+// Translate a local pointer aliased into our own rdma_buffer_ptr into the
+// peer-mapped pointer for rank `peer_rank`. Only valid when the IPC fast path
+// is active (`peer_rdma_bases != nullptr`).
+__device__ __forceinline__ uint64_t peer_ptr_of(uint64_t local_ptr, void* const* peer_rdma_bases,
+                                                void* rdma_buffer_ptr, int peer_rank) {
+  const auto off = local_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr);
+  return reinterpret_cast<uint64_t>(peer_rdma_bases[peer_rank]) + off;
+}
+
 // Cross-rank barrier via port-channel signal/wait ring.
 // Uses port channel `qp=0` across all connected peers.
 __device__ __forceinline__ void port_channel_barrier_block(
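
peer_ptr_of is plain base-plus-offset arithmetic: all ranks lay their buffers out identically, so an offset computed against our own base is valid against any peer's base. A worked example with made-up addresses:

// Illustration only. Our own buffer starts at local_base; peer_rdma_bases[1]
// is where rank 1's buffer happens to be mapped into *our* address space by
// cudaIpcOpenMemHandle.
constexpr uint64_t local_base = 0x7f0000000000ull;   // our rdma_buffer_ptr
constexpr uint64_t peer1_base = 0x7f8000000000ull;   // our mapping of rank 1's buffer
constexpr uint64_t local_ptr  = local_base + 0x1240; // a counter slot in our layout
// peer_ptr_of keeps the offset (0x1240) and swaps the base:
//   peer_ptr_of(local_ptr, peer_rdma_bases, rdma_buffer_ptr, /*peer_rank=*/1)
//   == peer1_base + 0x1240 == 0x7f8000001240
// i.e. the same slot inside rank 1's buffer, addressable from this GPU.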
@@ -66,17 +76,43 @@ __device__ __forceinline__ void port_channel_barrier_block(
   __syncthreads();
 }

+// Same barrier but using MemoryChannels (CUDA IPC, NVLink). Used on the
+// intra-node LL fast path.
+__device__ __forceinline__ void memory_channel_barrier_block(
+    mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+    int rank, int num_ranks) {
+  const int tid = threadIdx.x;
+  if (tid < num_ranks && tid != rank) {
+    memory_channel_handles[tid].signal();
+    memory_channel_handles[tid].wait();
+  }
+  __syncthreads();
+}
+
+template <bool kIpcPath>
+__device__ __forceinline__ void ll_barrier_block(
+    mscclpp::PortChannelDeviceHandle* port_channel_handles,
+    mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+    int rank, int num_ranks) {
+  if constexpr (kIpcPath) {
+    memory_channel_barrier_block(memory_channel_handles, rank, num_ranks);
+  } else {
+    port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+  }
+}
+
 // ---------------------------------------------------------------------------
 // clean_low_latency_buffer
 // ---------------------------------------------------------------------------

-template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
+template <int kNumThreads, bool kIpcPath> __launch_bounds__(kNumThreads, 1)
 __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                          int64_t* clean_1, int num_clean_int_1,
                                          mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                                         mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
                                          int rank, int num_ranks) {
   // Barrier before cleaning (in case of unfinished chunked EP)
-  port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+  ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);

   // Clean
   auto thread_id = static_cast<int>(threadIdx.x);
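
For intuition, a MemoryChannel signal()/wait() pair behaves roughly like a system-scope counter bump on the peer plus a spin on the local counter. The sketch below is a simplified standalone model of the same ring barrier over raw IPC-mapped flag arrays; it is not MSCCL++'s actual implementation. Epoch counting avoids having to reset flags between barriers.

// Simplified model of the signal/wait ring (not MSCCL++'s implementation).
// my_flags[r] counts signals received from rank r; peer_flags[r] points at
// rank r's flag array through the CUDA IPC mapping (see peer_ptr_of above).
__device__ void ipc_barrier_block_model(unsigned int* const* peer_flags,
                                        unsigned int* my_flags,
                                        int rank, int num_ranks,
                                        unsigned int epoch) {
  const int tid = static_cast<int>(threadIdx.x);
  if (tid < num_ranks && tid != rank) {
    // Signal peer `tid`: bump our slot in its flag array. System scope makes
    // the increment visible across devices on the same host.
    atomicAdd_system(&peer_flags[tid][rank], 1u);
    // Wait for peer `tid`'s matching signal in our own flag array
    // (atomicAdd of 0 is a system-scope read of the current value).
    while (atomicAdd_system(&my_flags[tid], 0u) < epoch) {}
  }
  __syncthreads();
}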
@@ -88,27 +124,35 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
     clean_1[i] = 0;

   // Barrier after cleaning (make sure low-latency mode work fine)
-  port_channel_barrier_block(port_channel_handles, rank, num_ranks);
+  ll_barrier_block<kIpcPath>(port_channel_handles, memory_channel_handles, rank, num_ranks);
 }

 void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                               int64_t* clean_1, int num_clean_int_1,
                               int rank, int num_ranks,
                               mscclpp::PortChannelDeviceHandle* port_channel_handles,
+                              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+                              bool use_ipc_path,
                               cudaStream_t stream) {
   constexpr int kNumThreads = 256;

   SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
-  LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
-                clean_0, num_clean_int_0, clean_1, num_clean_int_1,
-                port_channel_handles, rank, num_ranks);
+  if (use_ipc_path) {
+    LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, true>),
+                  clean_0, num_clean_int_0, clean_1, num_clean_int_1,
+                  port_channel_handles, memory_channel_handles, rank, num_ranks);
+  } else {
+    LAUNCH_KERNEL(&cfg, (clean_low_latency_buffer<kNumThreads, false>),
+                  clean_0, num_clean_int_0, clean_1, num_clean_int_1,
+                  port_channel_handles, memory_channel_handles, rank, num_ranks);
+  }
 }

 // ---------------------------------------------------------------------------
 // dispatch
 // ---------------------------------------------------------------------------

-template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -121,7 +165,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int num_topk, int num_experts, int rank, int num_ranks,
          int phases,
          void* rdma_buffer_ptr,
-         mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+         mscclpp::PortChannelDeviceHandle* port_channel_handles,
+         void* const* peer_rdma_bases,
+         mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
   const auto sm_id = static_cast<int>(blockIdx.x);
   const auto thread_id = static_cast<int>(threadIdx.x);
   const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -214,14 +260,22 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                          rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                          slot_idx * num_bytes_per_msg;
     if (dst_rank != rank) {
-      // MSCCL++ port-channel PUT (lane 0 issues one request).
-      if (lane_id == 0) {
-        const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
-        const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
-        port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
-            .put(dst_off, src_off, num_bytes_per_msg);
-      }
+      if constexpr (kIpcPath) {
+        // Peer-mapped warp copy over NVLink (CUDA IPC).
+        const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
+        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+        const auto dst_int4_ptr = reinterpret_cast<int4*>(peer_dst);
+        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+      } else {
+        // MSCCL++ port-channel PUT (lane 0 issues one request).
+        if (lane_id == 0) {
+          const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
+          const auto src_off = rdma_offset_of(src_ptr, rdma_buffer_ptr);
+          port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
+              .put(dst_off, src_off, num_bytes_per_msg);
+        }
+        __syncwarp();
+      }
       __syncwarp();
     } else {
       const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
       const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
@@ -279,10 +333,19 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,

       while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
       if (dst_rank != rank) {
-        auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
-        const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
-        port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
-            .atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
+        if constexpr (kIpcPath) {
+          // Single writer per (dst_expert_local_idx, rank) slot, so a
+          // release-ordered store on the peer-mapped counter is enough.
+          auto peer_counter = reinterpret_cast<int64_t*>(peer_ptr_of(
+              reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank),
+              peer_rdma_bases, rdma_buffer_ptr, dst_rank));
+          st_na_release(peer_counter, static_cast<int64_t>(-num_tokens_sent - 1));
+        } else {
+          auto* counter_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
+          const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(counter_ptr), rdma_buffer_ptr);
+          port_channel_handles[dst_expert_local_idx * num_ranks + dst_rank]
+              .atomicAdd(off, static_cast<int64_t>(-num_tokens_sent - 1));
+        }
       } else {
         st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
                       static_cast<int64_t>(-num_tokens_sent - 1));
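
The single-writer claim is what makes the plain store safe: exactly one warp writes each (dst_expert_local_idx, rank) counter slot, so atomicity buys nothing and only ordering matters. Release/acquire pairing provides it: the payload warp copy is ordered before the release store of the counter, and the receiver's acquire load orders its payload reads after observing the counter. A sketch of the two primitives in the style of the codebase's st_na_release / ld_acquire_global helpers (assumed shapes, not copied from the source; sm_70+ PTX):

// Assumed helper shapes; .sys scope so the ordering holds across GPUs
// on the same host via the IPC mapping.
__device__ __forceinline__ void st_release_sys(int64_t* ptr, int64_t val) {
  asm volatile("st.release.sys.global.s64 [%0], %1;" :: "l"(ptr), "l"(val) : "memory");
}

__device__ __forceinline__ int64_t ld_acquire_sys(const int64_t* ptr) {
  int64_t val;
  asm volatile("ld.acquire.sys.global.s64 %0, [%1];" : "=l"(val) : "l"(ptr) : "memory");
  return val;
}

// Writer (this hunk): copy payload, then release-store the count.
//   UNROLLED_WARP_COPY(...); st_release_sys(peer_counter, -num_tokens_sent - 1);
// Reader (recv phase): spin with acquire, then read the payload safely.
//   while ((cnt = ld_acquire_sys(counter)) == 0) {}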
@@ -367,7 +430,10 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases,
               void* rdma_buffer_ptr,
-              mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+              mscclpp::PortChannelDeviceHandle* port_channel_handles,
+              void* const* peer_rdma_bases,
+              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+              bool use_ipc_path) {
   constexpr int kNumMaxTopK = 9;
   constexpr int kNumWarpsPerGroup = 10;
   constexpr int kNumWarpGroups = 3;
@@ -382,19 +448,37 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
   EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

 #define DISPATCH_LAUNCH_CASE(hidden_case) { \
-auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> : \
-                               dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
-LAUNCH_KERNEL(&cfg, dispatch_func, \
-              packed_recv_x, packed_recv_x_scales, \
-              packed_recv_src_info, packed_recv_layout_range, \
-              packed_recv_count, \
-              rdma_recv_x, rdma_recv_count, rdma_x, \
-              x, topk_idx, \
-              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
-              next_clean, num_next_clean_int, \
-              num_tokens, num_max_dispatch_tokens_per_rank, \
-              num_topk, num_experts, rank, num_ranks, phases, \
-              rdma_buffer_ptr, port_channel_handles); } break
+if (use_ipc_path) { \
+  auto dispatch_func = use_fp8 ? dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> \
+                               : dispatch<false, true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
+  LAUNCH_KERNEL(&cfg, dispatch_func, \
+                packed_recv_x, packed_recv_x_scales, \
+                packed_recv_src_info, packed_recv_layout_range, \
+                packed_recv_count, \
+                rdma_recv_x, rdma_recv_count, rdma_x, \
+                x, topk_idx, \
+                atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+                next_clean, num_next_clean_int, \
+                num_tokens, num_max_dispatch_tokens_per_rank, \
+                num_topk, num_experts, rank, num_ranks, phases, \
+                rdma_buffer_ptr, port_channel_handles, \
+                peer_rdma_bases, memory_channel_handles); \
+} else { \
+  auto dispatch_func = use_fp8 ? dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case> \
+                               : dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case>; \
+  LAUNCH_KERNEL(&cfg, dispatch_func, \
+                packed_recv_x, packed_recv_x_scales, \
+                packed_recv_src_info, packed_recv_layout_range, \
+                packed_recv_count, \
+                rdma_recv_x, rdma_recv_count, rdma_x, \
+                x, topk_idx, \
+                atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+                next_clean, num_next_clean_int, \
+                num_tokens, num_max_dispatch_tokens_per_rank, \
+                num_topk, num_experts, rank, num_ranks, phases, \
+                rdma_buffer_ptr, port_channel_handles, \
+                peer_rdma_bases, memory_channel_handles); \
+} } break

 SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
 SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
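
The duplicated branches in DISPATCH_LAUNCH_CASE (and COMBINE_LAUNCH_CASE below) are the cost of lifting a runtime flag into a template parameter inside a macro. A generic alternative, sketched here and not used by this patch, performs that lift once with a helper; the patch's explicit if/else arguably reads better next to the existing macro style.

// Sketch of an alternative (not in this patch): lift a runtime bool into a
// compile-time constant exactly once, so each launch site is written once.
#include <type_traits>
#include <utility>

template <typename F>
void with_constexpr_bool(bool value, F&& f) {
  if (value)
    std::forward<F>(f)(std::integral_constant<bool, true>{});
  else
    std::forward<F>(f)(std::integral_constant<bool, false>{});
}

// Hypothetical usage at a launch site:
//   with_constexpr_bool(use_ipc_path, [&](auto ipc) {
//     auto kernel = dispatch<true, ipc.value, kNumWarpGroups, kNumWarpsPerGroup, kHidden>;
//     LAUNCH_KERNEL(&cfg, kernel, /* ... same argument list ... */);
//   });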
@@ -405,7 +489,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
 // combine
 // ---------------------------------------------------------------------------

-template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
+template <bool kIpcPath, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 combine(void* combined_x,
         void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
@@ -418,7 +502,9 @@ combine(void* combined_x,
         int num_experts, int rank, int num_ranks,
         int phases, bool zero_copy,
         void* rdma_buffer_ptr,
-        mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+        mscclpp::PortChannelDeviceHandle* port_channel_handles,
+        void* const* peer_rdma_bases,
+        mscclpp::MemoryChannelDeviceHandle* memory_channel_handles) {
   const auto sm_id = static_cast<int>(blockIdx.x);
   const auto num_sms = static_cast<int>(gridDim.x);
   const auto thread_id = static_cast<int>(threadIdx.x);
@@ -474,17 +560,25 @@ combine(void* combined_x,
       const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
       UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
     } else {
-      const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
-      if (not zero_copy)
-        UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
-      // MSCCL++ port-channel PUT.
-      if (lane_id == 0) {
-        const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
-        const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
-        port_channel_handles[local_expert_idx * num_ranks + dst_rank]
-            .put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
-      }
+      if constexpr (kIpcPath) {
+        // Peer-mapped warp copy over NVLink. `zero_copy` is irrelevant
+        // on this path because we skip the rdma_send staging buffer.
+        const auto peer_dst = peer_ptr_of(dst_ptr, peer_rdma_bases, rdma_buffer_ptr, dst_rank);
+        const auto peer_dst_int4 = reinterpret_cast<int4*>(peer_dst);
+        UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, peer_dst_int4, x_int4, ld_nc_global, st_na_global);
+      } else {
+        const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
+        if (not zero_copy)
+          UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
+        // MSCCL++ port-channel PUT.
+        if (lane_id == 0) {
+          const auto dst_off = rdma_offset_of(dst_ptr, rdma_buffer_ptr);
+          const auto src_off = rdma_offset_of(static_cast<uint64_t>(buf_ptr), rdma_buffer_ptr);
+          port_channel_handles[local_expert_idx * num_ranks + dst_rank]
+              .put(dst_off, src_off, hidden * sizeof(nv_bfloat16));
+        }
+        __syncwarp();
+      }
       __syncwarp();
     }
   }
@@ -493,10 +587,17 @@ combine(void* combined_x,
     if (sub_warp_id == 1 and lane_id == 0) {
       while (ld_acquire_global(atomic_clean_flag) == 0);
       if (dst_rank != rank) {
-        auto* flag_ptr = rdma_recv_flag + global_expert_idx;
-        const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
-        port_channel_handles[local_expert_idx * num_ranks + dst_rank]
-            .atomicAdd(off, static_cast<int64_t>(1));
+        if constexpr (kIpcPath) {
+          auto peer_flag = reinterpret_cast<int64_t*>(peer_ptr_of(
+              reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx),
+              peer_rdma_bases, rdma_buffer_ptr, dst_rank));
+          st_na_release(peer_flag, static_cast<int64_t>(1));
+        } else {
+          auto* flag_ptr = rdma_recv_flag + global_expert_idx;
+          const auto off = rdma_offset_of(reinterpret_cast<uint64_t>(flag_ptr), rdma_buffer_ptr);
+          port_channel_handles[local_expert_idx * num_ranks + dst_rank]
+              .atomicAdd(off, static_cast<int64_t>(1));
+        }
       } else {
         st_na_release(rdma_recv_flag + global_expert_idx, static_cast<int64_t>(1));
       }
@@ -561,7 +662,10 @@ void combine(void* combined_x,
              void* workspace, cudaStream_t stream,
              int phases, bool zero_copy,
              void* rdma_buffer_ptr,
-             mscclpp::PortChannelDeviceHandle* port_channel_handles) {
+             mscclpp::PortChannelDeviceHandle* port_channel_handles,
+             void* const* peer_rdma_bases,
+             mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
+             bool use_ipc_path) {
   constexpr int kNumWarpsPerGroup = 10;
   constexpr int kNumWarpGroups = 3;
   constexpr int kNumMaxTopk = 9;
constexpr int kNumMaxTopk = 9;
|
||||
@@ -574,18 +678,35 @@ void combine(void* combined_x,
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(hidden_case) { \
|
||||
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
combined_x, \
|
||||
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
|
||||
x, topk_idx, topk_weights, src_info, layout_range, \
|
||||
next_clean, num_next_clean_int, \
|
||||
atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases, zero_copy, \
|
||||
rdma_buffer_ptr, port_channel_handles); } break
|
||||
if (use_ipc_path) { \
|
||||
auto combine_func = combine<true, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
combined_x, \
|
||||
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
|
||||
x, topk_idx, topk_weights, src_info, layout_range, \
|
||||
next_clean, num_next_clean_int, \
|
||||
atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases, zero_copy, \
|
||||
rdma_buffer_ptr, port_channel_handles, \
|
||||
peer_rdma_bases, memory_channel_handles); \
|
||||
} else { \
|
||||
auto combine_func = combine<false, kNumWarpGroups, kNumWarpsPerGroup, hidden_case, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
combined_x, \
|
||||
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
|
||||
x, topk_idx, topk_weights, src_info, layout_range, \
|
||||
next_clean, num_next_clean_int, \
|
||||
atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases, zero_copy, \
|
||||
rdma_buffer_ptr, port_channel_handles, \
|
||||
peer_rdma_bases, memory_channel_handles); \
|
||||
} } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
|
||||
|
||||