From 5d16ac958ed6f84f3de3f837b8191cffd22db24a Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Fri, 8 May 2026 01:42:21 +0000 Subject: [PATCH] EP GB200 (4 GPUs/node) support - configs.cuh: NUM_MAX_NVL_PEERS 8 -> 4 - internode.cu: introduce NvlPackT (uint64_t for 8 peers, uint32_t for 4) to handle packed-bool loads of is_token_in_rank; relax SourceMeta static_assert; replace 4 uint64_t-coupled sites - buffer.hpp/buffer.cc: relax NUM_MAX_NVL_PEERS assert (4 || 8); read MSCCLPP_EP_LOCAL_WORLD_SIZE env to override rdma_rank/nvl_rank partitioning when local world size != NUM_MAX_NVL_PEERS - CMakeLists.txt (ext/ep): rpath / install fix - pyproject.toml: MSCCLPP_BUILD_EXT_EP=ON - src/core/atomicadd_kernel.cu, kernels/buffer.cuh, kernels/utils.cuh: related EP fixes - test_internode_multirank.py: NUM_MAX_NVL_PEERS=4, rank %% 4 --- pyproject.toml | 1 + src/core/atomicadd_kernel.cu | 2 +- src/ext/ep/CMakeLists.txt | 4 +-- src/ext/ep/buffer.cc | 15 ++++++++-- src/ext/ep/buffer.hpp | 3 +- src/ext/ep/kernels/buffer.cuh | 1 + src/ext/ep/kernels/configs.cuh | 2 +- src/ext/ep/kernels/internode.cu | 30 ++++++++++++------- src/ext/ep/kernels/utils.cuh | 1 + .../python/ext/ep/test_internode_multirank.py | 4 +-- 10 files changed, 44 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ea569cb..816052b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ include= ["python/mscclpp/_version.py"] [tool.scikit-build.cmake.define] MSCCLPP_BUILD_PYTHON_BINDINGS = "ON" MSCCLPP_BUILD_TESTS = "OFF" +MSCCLPP_BUILD_EXT_EP = "ON" [tool.black] line-length = 120 diff --git a/src/core/atomicadd_kernel.cu b/src/core/atomicadd_kernel.cu index 779c6c1b..0b4ade7b 100644 --- a/src/core/atomicadd_kernel.cu +++ b/src/core/atomicadd_kernel.cu @@ -37,7 +37,7 @@ void CudaIpcStream::atomicAdd(uint64_t* dst, int64_t value) { CUresult res = cuDeviceGet(&cuDevice, deviceId_); if (res != CUDA_SUCCESS) throw Error("cuDeviceGet failed", ErrorCode::InternalError); - res = 
cuCtxCreate(&proxyAtomicCtx_, 0, cuDevice); + res = cuCtxCreate(&proxyAtomicCtx_, NULL, 0, cuDevice); if (res != CUDA_SUCCESS) throw Error("cuCtxCreate failed", ErrorCode::InternalError); cuCtxPopCurrent(nullptr); diff --git a/src/ext/ep/CMakeLists.txt b/src/ext/ep/CMakeLists.txt index c32132c7..19f77153 100644 --- a/src/ext/ep/CMakeLists.txt +++ b/src/ext/ep/CMakeLists.txt @@ -93,7 +93,7 @@ set_target_properties(mscclpp_ep_cpp PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CXX_VISIBILITY_PRESET default - INSTALL_RPATH "\$ORIGIN/../lib" + INSTALL_RPATH "\$ORIGIN/mscclpp/lib" ) if(MSCCLPP_USE_CUDA) @@ -103,4 +103,4 @@ elseif(MSCCLPP_USE_ROCM) endif() install(TARGETS mscclpp_ep_cpp - LIBRARY DESTINATION ${INSTALL_PREFIX}/lib) + LIBRARY DESTINATION ..) diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc index 35702f8c..8ba81020 100644 --- a/src/ext/ep/buffer.cc +++ b/src/ext/ep/buffer.cc @@ -93,8 +93,19 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ // Get ranks CUDA_CHECK(cudaGetDevice(&device_id)); - rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + // Allow overriding the local-world-size (number of GPUs per node) via the + // env var MSCCLPP_EP_LOCAL_WORLD_SIZE. By default the partitioning is + // pinned to NUM_MAX_NVL_PEERS=8, which mis-classifies all ranks as + // intra-node on hosts with fewer than 8 GPUs (e.g. GB200x4) and breaks + // cross-node LL via spurious cudaIpcOpenMemHandle on remote IPC handles. 
+ int local_world_size = NUM_MAX_NVL_PEERS; + if (const char* env = std::getenv("MSCCLPP_EP_LOCAL_WORLD_SIZE")) { + int v = std::atoi(env); + if (v > 0 && v <= NUM_MAX_NVL_PEERS) local_world_size = v; + } + rdma_rank = rank / local_world_size, nvl_rank = rank % local_world_size; + num_rdma_ranks = std::max(1, num_ranks / local_world_size), + num_nvl_ranks = std::min(num_ranks, local_world_size); // Get device info cudaDeviceProp device_prop = {}; diff --git a/src/ext/ep/buffer.hpp b/src/ext/ep/buffer.hpp index 7c2a0540..edb0fa98 100644 --- a/src/ext/ep/buffer.hpp +++ b/src/ext/ep/buffer.hpp @@ -25,7 +25,8 @@ namespace mscclpp { namespace ep { struct Buffer { - EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8 || NUM_MAX_NVL_PEERS == 4, + "The number of maximum NVLink peers must be 4 or 8"); private: // Low-latency mode buffer diff --git a/src/ext/ep/kernels/buffer.cuh b/src/ext/ep/kernels/buffer.cuh index b48bd588..e006969d 100644 --- a/src/ext/ep/kernels/buffer.cuh +++ b/src/ext/ep/kernels/buffer.cuh @@ -4,6 +4,7 @@ #include "configs.cuh" #include "exception.cuh" +#include namespace mscclpp { namespace ep { diff --git a/src/ext/ep/kernels/configs.cuh b/src/ext/ep/kernels/configs.cuh index 7f413a6b..1d8b146b 100644 --- a/src/ext/ep/kernels/configs.cuh +++ b/src/ext/ep/kernels/configs.cuh @@ -11,7 +11,7 @@ #pragma once -#define NUM_MAX_NVL_PEERS 8 +#define NUM_MAX_NVL_PEERS 4 #define NUM_MAX_RDMA_PEERS 20 #define NUM_MAX_FIFO_SLOTS 32768 #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu index 44a3b06d..6d867fed 100644 --- a/src/ext/ep/kernels/internode.cu +++ b/src/ext/ep/kernels/internode.cu @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
 #include
+#include <type_traits>
 #include
 #include
@@ -13,6 +14,17 @@
 namespace mscclpp {
 namespace ep {
 
+// Packed type for loading NUM_MAX_NVL_PEERS bools as a single integer.
+// Supports the two configurations the rest of the kernel logic assumes:
+//   8 peers -> uint64_t (8 bools fit exactly)
+//   4 peers -> uint32_t (4 bools fit exactly)
+// All packed loads/stores below use this alias instead of a hardcoded uint64_t.
+using NvlPackT = std::conditional_t<NUM_MAX_NVL_PEERS == 8, uint64_t, uint32_t>;
+static_assert(NUM_MAX_NVL_PEERS == 8 || NUM_MAX_NVL_PEERS == 4,
+              "NUM_MAX_NVL_PEERS must be 4 or 8 for HT internode kernel");
+static_assert(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(NvlPackT),
+              "NvlPackT size must match NUM_MAX_NVL_PEERS bools");
+
 namespace internode {
 
 template
@@ -137,7 +149,7 @@
 
 struct SourceMeta {
   int src_rdma_rank, is_token_in_nvl_rank_bits;
 
-  EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
+  EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8 || NUM_MAX_NVL_PEERS == 4, "Invalid number of maximum NVL peers");
 
   __forceinline__ SourceMeta() = default;
 
@@ -389,13 +401,12 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_co
       // Iterate over tokens
       int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
       for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) {
-        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
-        auto is_token_in_rank_uint64 =
-            *reinterpret_cast<const uint64_t*>(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
-        auto is_token_in_rank_values = reinterpret_cast<const bool*>(&is_token_in_rank_uint64);
+        auto is_token_in_rank_packed =
+            *reinterpret_cast<const NvlPackT*>(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
+        auto is_token_in_rank_values = reinterpret_cast<const bool*>(&is_token_in_rank_packed);
 #pragma unroll
         for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) per_nvl_rank_count[j] += is_token_in_rank_values[j];
-        total_count += (is_token_in_rank_uint64 != 0);
+        total_count += (is_token_in_rank_packed != 0);
       }
 
       // Warp reduce
@@ -527,8 +538,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
   // Data checks
   EP_DEVICE_ASSERT(num_topk <= 32);
 
-  // RDMA symmetric layout
-  EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
+  // RDMA symmetric layout (packed-bool size guard is at namespace scope via NvlPackT).
   auto hidden_bytes = hidden_int4 * sizeof(int4);
   auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
   auto rdma_channel_data =
@@ -663,10 +673,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
         lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
     for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
       // Read RDMA rank existence
-      uint64_t is_token_in_rank_uint64 = 0;
+      NvlPackT is_token_in_rank_uint64 = 0;
       if (lane_id < kNumRDMARanks)
         is_token_in_rank_uint64 =
-            *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
+            *reinterpret_cast<const NvlPackT*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
 
       // Acquire sequential lock
       while (lane_id == 0 and rdma_send_next_token_idx != token_idx)
diff --git a/src/ext/ep/kernels/utils.cuh b/src/ext/ep/kernels/utils.cuh
index 3fb01fe4..fadd2488 100644
--- a/src/ext/ep/kernels/utils.cuh
+++ b/src/ext/ep/kernels/utils.cuh
@@ -5,6 +5,7 @@
 #include
 
 #include "exception.cuh"
+#include
 
 #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
 { \
diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py
index b8a1ba0b..c7a354e7 100644
--- a/test/python/ext/ep/test_internode_multirank.py
+++ b/test/python/ext/ep/test_internode_multirank.py
@@ -44,7 +44,7 @@ import torch.distributed as dist
 def init_dist():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    local_rank = int(os.environ.get("LOCAL_RANK", rank % 8))
+    local_rank = int(os.environ.get("LOCAL_RANK", rank % 4))
     torch.cuda.set_device(local_rank)
     dist.init_process_group(
         backend="nccl", world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{local_rank}")
@@ -71,7 +71,7 @@ def main():
     rank, num_ranks, local_rank, group = init_dist()
 
     from mscclpp.ext import ep
 
-    NUM_MAX_NVL_PEERS = 8
+    NUM_MAX_NVL_PEERS = 4
     assert (
         num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS
-    ), f"expected >1 node with 8 GPUs each, got num_ranks={num_ranks}"
+    ), f"expected >1 node with {NUM_MAX_NVL_PEERS} GPUs each, got num_ranks={num_ranks}"