From ec011f14ea89fa2607b4ea31e811007c41f8815a Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Wed, 25 Mar 2026 01:52:24 +0000
Subject: [PATCH] Add detection of torch baseline failures and debug info

---
 python/test/test_alltoallv_mscclpp.py            | 35 +++++++++---
 .../alltoallv/alltoallv_fullmesh.cu              | 54 ++++++++++++++++---
 2 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index 8ff40258..95dbf044 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -295,10 +295,10 @@ def main():
     use_torch_baseline = (backend == "nccl")
     if use_torch_baseline:
         try:
-            # Quick test: if the NCCL shim is active it may not support all_to_all_single
             tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
             tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
             dist.all_to_all_single(tiny_out, tiny_in)
+            torch.cuda.synchronize()
         except Exception:
             use_torch_baseline = False
             if rank == 0:
@@ -387,8 +387,19 @@
 
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
-            print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f"  [WARN] torch baseline failed: {e}")
+                    print(f"  [INFO] Disabling torch baseline for remaining sizes")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw)
         else:
             print_row(fmt_size(avg_msg_size), m_lat, m_bw)
 
@@ -459,12 +470,22 @@
 
         n_warmup, n_iters = 5, 20
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+        avg_bytes = total_bytes // world_size
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
-            avg_bytes = total_bytes // world_size
-            print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f"  [WARN] torch baseline failed: {e}")
+                    print(f"  [INFO] Disabling torch baseline for remaining workloads")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_bytes), m_lat, m_bw)
         else:
-            avg_bytes = total_bytes // world_size
             print_row(fmt_size(avg_bytes), m_lat, m_bw)
     else:
         if rank == 0:
diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
index e89f6de4..ca945361 100644
--- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
+++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
@@ -14,6 +14,7 @@
 
 #include 
 #include 
+#include "debug.h"
 
 namespace mscclpp {
 namespace collective {
@@ -96,20 +97,21 @@ void AlltoallvFullmesh::initialize(std::shared_ptr comm) {
   int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
   int localGpuIdx = rank % nRanksPerNode;
   bool isMultiNode = (worldSize_ > nRanksPerNode);
+  bool nvlsSupported = isNvlsSupported();
+  int ibDevCount = getIBDeviceCount();
+
+  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initialize: worldSize=%d, nRanksPerNode=%d, "
+       "isMultiNode=%d, isNvlsSupported=%d, ibDevCount=%d, localGpuIdx=%d",
+       rank, worldSize_, nRanksPerNode, isMultiNode, nvlsSupported, ibDevCount, localGpuIdx);
 
   if (!isMultiNode) {
-    // ── Single-node: CudaIpc for all peers ─────────────────────────────
     multiNodeMode_ = MultiNodeMode::SingleNode;
     this->conns_ = setupConnections(comm);
-  } else if (isNvlsSupported()) {
-    // ── GB200 NVSwitch: CudaIpc for ALL peers + staging GpuBuffers ─────
-    // GpuBuffer uses cuMemCreate → Fabric handles → cross-node CudaIpc works.
+  } else if (nvlsSupported) {
     multiNodeMode_ = MultiNodeMode::NVSwitch;
     this->conns_ = setupConnections(comm);
   } else {
-    // ── IB: CudaIpc intra-node + IB inter-node ────────────────────────
-    // For non-NVSwitch systems (H100 etc.) where CudaIpc doesn't work cross-node.
-    if (getIBDeviceCount() <= 0) {
+    if (ibDevCount <= 0) {
       throw Error("Multi-node alltoallv requires IB transport but no IB devices found. "
                   "Ensure IB drivers are loaded and devices are available.",
                   ErrorCode::InvalidUsage);
@@ -117,6 +119,15 @@
     multiNodeMode_ = MultiNodeMode::IB;
     this->conns_ = setupHybridConnections(comm, localGpuIdx);
   }
+
+  const char* modeStr = (multiNodeMode_ == MultiNodeMode::SingleNode) ? "SingleNode" :
+                        (multiNodeMode_ == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
+  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] mode=%s, connections=%zu",
+       rank, modeStr, this->conns_.size());
+  for (size_t i = 0; i < this->conns_.size(); ++i) {
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] conn[%zu] transport=%d",
+         rank, i, (int)this->conns_[i].transport());
+  }
 }
 
 CommResult AlltoallvFullmesh::alltoallvKernelFunc(
@@ -237,24 +248,45 @@ std::shared_ptr AlltoallvFullmesh::initAlltoallvContext(
 
   int rank = ctx->rank;
   int localGpuIdx = rank % ctx->nRanksPerNode;
+  const char* modeStr = (ctx->mode == MultiNodeMode::SingleNode) ? "SingleNode" :
+                        (ctx->mode == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
+  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initContext: mode=%s, useStaging=%d, "
+       "input=%p (%zu B), output=%p (%zu B), localGpuIdx=%d",
+       rank, modeStr, ctx->useStaging, input, inputSize, output, outputSize, localGpuIdx);
 
   if (ctx->mode == MultiNodeMode::NVSwitch) {
     // ── NVSwitch (GB200): staging GpuBuffers + CudaIpc MemoryChannel for all peers
     ctx->inputStaging = std::make_shared>(inputSize);
     ctx->outputStaging = std::make_shared>(outputSize);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch staging: input=%p (%zu B), output=%p (%zu B)",
+         rank, ctx->inputStaging->data(), inputSize, ctx->outputStaging->data(), outputSize);
 
     TransportFlags allTransports = Transport::CudaIpc;
     RegisteredMemory inputBufRegMem = comm->registerMemory(
         ctx->inputStaging->data(), ctx->inputStaging->bytes(), allTransports);
     RegisteredMemory outputBufRegMem = comm->registerMemory(
         ctx->outputStaging->data(), ctx->outputStaging->bytes(), allTransports);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: registered input=%p, output=%p",
+         rank, inputBufRegMem.data(), outputBufRegMem.data());
 
     std::vector remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
+    for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
+      INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: remoteOutput[%zu] data=%p, size=%zu",
+           rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
+      if (remoteOutputMemories[i].data() == nullptr) {
+        INFO(MSCCLPP_COLL, "[alltoallv][rank %d] ERROR: remoteOutput[%zu] has NULL data pointer! "
+             "Cross-node CudaIpc mapping failed.", rank, i);
+      }
+    }
 
     constexpr int nChannelsPerConnection = 1;
     ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu semaphores created",
+         rank, ctx->memorySemaphores.size());
     ctx->memoryChannels = setupMemoryChannels(
         this->conns_, ctx->memorySemaphores, remoteOutputMemories, inputBufRegMem, nChannelsPerConnection);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu memoryChannels created",
+         rank, ctx->memoryChannels.size());
     ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
 
     ctx->registeredMemories = std::move(remoteOutputMemories);
@@ -268,12 +300,20 @@
 
     RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);
     std::vector remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: input=%p (%zu B), output=%p (%zu B), remotes=%zu",
+         rank, input, inputSize, output, outputSize, remoteOutputMemories.size());
+    for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
+      INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: remoteOutput[%zu] data=%p, size=%zu",
+           rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
+    }
 
     ctx->proxyService = std::make_shared();
     ctx->portChannels = setupAllPortChannels(
         ctx->proxyService, *comm, this->conns_, remoteOutputMemories, inputBufRegMem);
     ctx->portChannelDeviceHandles = setupPortChannelDeviceHandles(ctx->portChannels);
     ctx->proxyService->startProxy(true);
+    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: %zu portChannels created, proxy started",
+         rank, ctx->portChannels.size());
 
     ctx->registeredMemories = std::move(remoteOutputMemories);
     ctx->registeredMemories.push_back(inputBufRegMem);