mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 05:19:24 +00:00
The NVLS HT B2 path introduced in 3ab2e43b was activated whenever isNvlsSupported() && num_rdma_ranks > 1. On H100 NDv5 / Azure CX-7 RoCE that condition is true, because H100 has intra-node NVLink multicast. However, there is no cross-host NVSwitch fabric.

mscclpp's GpuIpcMem::create then falls back to CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR. Its handle exchange goes through /tmp/mscclpp_bootstrap_<pid>.sock, which is a Unix-domain socket owned by master rank 0. Worker ranks on other hosts cannot reach it.

Every commit since 3ab2e43b fails with this error:

  RuntimeError: connect() failed for unix socket to /tmp/mscclpp_bootstrap_<pid>.sock

Setting MSCCLPP_EP_FABRIC_IPC=0 did not help because the variable was silently ignored.

src/ext/ep/buffer.cc: add a resolve_fabric_ipc_supported() helper. Resolution order:
1. The MSCCLPP_EP_FABRIC_IPC env var: 0/off/false/no => off; 1/on/true/yes/force => on; anything else => auto.
2. Auto-detect, which requires both:
   - CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED == 1
   - device compute capability >= sm_100 (Blackwell+)
Gate both use_fabric_ipc_alloc (RDMA buffer allocator) and nvls_ht_enabled (HT B2 multicast region) on fabric_ipc_supported. On H100 both fall back to cudaMalloc + legacy PortChannel. On GB200 NVL72 both remain enabled. The diagnostic prints now include fabric_ipc=.

test/python/ext/ep/test_internode_multirank.py: replace the hardcoded NUM_MAX_NVL_PEERS=4 with a runtime _detect_local_world_size() helper. The helper reads MSCCLPP_EP_LOCAL_WORLD_SIZE / LOCAL_WORLD_SIZE / OMPI_COMM_WORLD_LOCAL_SIZE and falls back to torch.cuda.device_count(). The test now runs correctly on both H100 (8 GPUs/node) and GB200 (4 GPUs/node) without code changes.

src/core/atomicadd_kernel.cu: use cuCtxCreate_v4 for CUDA >= 12.5 (the underlying symbol was renamed). Keep the legacy 3-arg cuCtxCreate for older toolkits.
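The resolution order for fabric_ipc_supported can be sketched as a pure decision function. This is a hedged illustration, not the code in src/ext/ep/buffer.cc. The names FabricIpcOverride, parse_fabric_ipc_override and resolve_fabric_ipc_supported_sketch are hypothetical.

The device facts are passed in as arguments instead of being queried. The real helper would presumably read CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED and the compute capability through the CUDA driver API. Case-insensitive matching of the env value is also an assumption.

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Hypothetical tri-state for the MSCCLPP_EP_FABRIC_IPC override.
enum class FabricIpcOverride { Off, On, Auto };

// Parse the env value; unset (nullptr) or unrecognized values mean "auto".
// Assumption: matching is case-insensitive.
FabricIpcOverride parse_fabric_ipc_override(const char* value) {
  if (value == nullptr) return FabricIpcOverride::Auto;
  std::string v(value);
  std::transform(v.begin(), v.end(), v.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  if (v == "0" || v == "off" || v == "false" || v == "no") return FabricIpcOverride::Off;
  if (v == "1" || v == "on" || v == "true" || v == "yes" || v == "force") return FabricIpcOverride::On;
  return FabricIpcOverride::Auto;
}

// Hypothetical pure version of resolve_fabric_ipc_supported().
//   env_value:               contents of MSCCLPP_EP_FABRIC_IPC, or nullptr if unset
//   fabric_handle_supported: value of CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED
//   cc_major:                compute capability major (10 == sm_100, Blackwell)
bool resolve_fabric_ipc_supported_sketch(const char* env_value,
                                         int fabric_handle_supported,
                                         int cc_major) {
  switch (parse_fabric_ipc_override(env_value)) {
    case FabricIpcOverride::Off: return false;  // explicit opt-out always wins
    case FabricIpcOverride::On: return true;    // explicit opt-in / force
    case FabricIpcOverride::Auto: break;        // fall through to auto-detect
  }
  // Auto-detect: both the fabric handle attribute and sm_100+ are required.
  return fabric_handle_supported == 1 && cc_major >= 10;
}
```

Keeping the driver queries outside the decision function lets the override and auto-detect table be unit-tested on a host without a GPU. On H100 (sm_90) the auto path returns false even if the fabric attribute reports support.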
Verified on 2x H100 NDv5 at HEAD:
- LL intranode (8 GPUs): PASS
- LL internode (16 GPUs, 2 nodes): PASS
- HT intranode (8 GPUs): PASS
- HT internode (16 GPUs, 2 nodes): PASS

Diagnostic on H100:
  [mscclpp_ep] rdma_buffer allocator: cudaMalloc (low_latency=0, nvls=1, fabric_ipc=0)
  [mscclpp_ep] NVLS HT multicast: disabled (low_latency=0, num_rdma_ranks=2, nvls_supported=1, fabric_ipc=0)
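The "disabled" line on H100 follows directly from the new gate. This is a minimal sketch, assuming the gate is exactly isNvlsSupported() && num_rdma_ranks > 1 && fabric_ipc_supported, as described in the commit. NvlsHtInputs, nvls_ht_enabled and nvls_ht_diagnostic are hypothetical names. low_latency is only echoed into the message, since the commit does not say whether the real gate checks it. The "enabled" wording is also assumed.

```cpp
#include <string>

// Hypothetical inputs mirroring the fields printed in the diagnostic line.
struct NvlsHtInputs {
  bool low_latency;           // echoed only; not part of the gate in this sketch
  int num_rdma_ranks;
  bool nvls_supported;        // stands in for isNvlsSupported()
  bool fabric_ipc_supported;  // result of the fabric IPC resolution
};

// Gate as described: NVLS multicast, more than one RDMA rank, and fabric IPC.
bool nvls_ht_enabled(const NvlsHtInputs& in) {
  return in.nvls_supported && in.num_rdma_ranks > 1 && in.fabric_ipc_supported;
}

// Builds a line in the same shape as the "[mscclpp_ep] NVLS HT multicast" print.
std::string nvls_ht_diagnostic(const NvlsHtInputs& in) {
  return std::string("[mscclpp_ep] NVLS HT multicast: ") +
         (nvls_ht_enabled(in) ? "enabled" : "disabled") +
         " (low_latency=" + std::to_string(in.low_latency ? 1 : 0) +
         ", num_rdma_ranks=" + std::to_string(in.num_rdma_ranks) +
         ", nvls_supported=" + std::to_string(in.nvls_supported ? 1 : 0) +
         ", fabric_ipc=" + std::to_string(in.fabric_ipc_supported ? 1 : 0) + ")";
}
```

Using the H100 values {low_latency=false, num_rdma_ranks=2, nvls_supported=true, fabric_ipc_supported=false}, the function reproduces the H100 multicast line above. Flipping fabric_ipc_supported to true, as on GB200 NVL72, turns the gate on.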