Add detection of torch baseline and debug info

This commit is contained in:
Qinghua Zhou
2026-03-25 01:52:24 +00:00
parent 8e22010560
commit ec011f14ea
2 changed files with 75 additions and 14 deletions

View File

@@ -295,10 +295,10 @@ def main():
use_torch_baseline = (backend == "nccl")
if use_torch_baseline:
try:
# Quick test: if the NCCL shim is active it may not support all_to_all_single
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
dist.all_to_all_single(tiny_out, tiny_in)
torch.cuda.synchronize()
except Exception:
use_torch_baseline = False
if rank == 0:
@@ -387,8 +387,19 @@ def main():
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
if use_torch_baseline:
t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
try:
t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
except Exception as e:
if rank == 0:
print(f" [WARN] torch baseline failed: {e}")
print(f" [INFO] Disabling torch baseline for remaining sizes")
use_torch_baseline = False
try:
torch.cuda.synchronize()
except Exception:
pass
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
else:
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
@@ -459,12 +470,22 @@ def main():
n_warmup, n_iters = 5, 20
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
avg_bytes = total_bytes // world_size
if use_torch_baseline:
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
try:
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
except Exception as e:
if rank == 0:
print(f" [WARN] torch baseline failed: {e}")
print(f" [INFO] Disabling torch baseline for remaining workloads")
use_torch_baseline = False
try:
torch.cuda.synchronize()
except Exception:
pass
print_row(fmt_size(avg_bytes), m_lat, m_bw)
else:
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw)
else:
if rank == 0:

View File

@@ -14,6 +14,7 @@
#include <mscclpp/utils.hpp>
#include <algorithm>
#include "debug.h"
namespace mscclpp {
namespace collective {
@@ -96,20 +97,21 @@ void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
int localGpuIdx = rank % nRanksPerNode;
bool isMultiNode = (worldSize_ > nRanksPerNode);
bool nvlsSupported = isNvlsSupported();
int ibDevCount = getIBDeviceCount();
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initialize: worldSize=%d, nRanksPerNode=%d, "
"isMultiNode=%d, isNvlsSupported=%d, ibDevCount=%d, localGpuIdx=%d",
rank, worldSize_, nRanksPerNode, isMultiNode, nvlsSupported, ibDevCount, localGpuIdx);
if (!isMultiNode) {
// ── Single-node: CudaIpc for all peers ─────────────────────────────
multiNodeMode_ = MultiNodeMode::SingleNode;
this->conns_ = setupConnections(comm);
} else if (isNvlsSupported()) {
// ── GB200 NVSwitch: CudaIpc for ALL peers + staging GpuBuffers ─────
// GpuBuffer uses cuMemCreate → Fabric handles → cross-node CudaIpc works.
} else if (nvlsSupported) {
multiNodeMode_ = MultiNodeMode::NVSwitch;
this->conns_ = setupConnections(comm);
} else {
// ── IB: CudaIpc intra-node + IB inter-node ────────────────────────
// For non-NVSwitch systems (H100 etc.) where CudaIpc doesn't work cross-node.
if (getIBDeviceCount() <= 0) {
if (ibDevCount <= 0) {
throw Error("Multi-node alltoallv requires IB transport but no IB devices found. "
"Ensure IB drivers are loaded and devices are available.",
ErrorCode::InvalidUsage);
@@ -117,6 +119,15 @@ void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
multiNodeMode_ = MultiNodeMode::IB;
this->conns_ = setupHybridConnections(comm, localGpuIdx);
}
const char* modeStr = (multiNodeMode_ == MultiNodeMode::SingleNode) ? "SingleNode" :
(multiNodeMode_ == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] mode=%s, connections=%zu",
rank, modeStr, this->conns_.size());
for (size_t i = 0; i < this->conns_.size(); ++i) {
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] conn[%zu] transport=%d",
rank, i, (int)this->conns_[i].transport());
}
}
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
@@ -237,24 +248,45 @@ std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
int rank = ctx->rank;
int localGpuIdx = rank % ctx->nRanksPerNode;
const char* modeStr = (ctx->mode == MultiNodeMode::SingleNode) ? "SingleNode" :
(ctx->mode == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initContext: mode=%s, useStaging=%d, "
"input=%p (%zu B), output=%p (%zu B), localGpuIdx=%d",
rank, modeStr, ctx->useStaging, input, inputSize, output, outputSize, localGpuIdx);
if (ctx->mode == MultiNodeMode::NVSwitch) {
// ── NVSwitch (GB200): staging GpuBuffers + CudaIpc MemoryChannel for all peers
ctx->inputStaging = std::make_shared<GpuBuffer<char>>(inputSize);
ctx->outputStaging = std::make_shared<GpuBuffer<char>>(outputSize);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch staging: input=%p (%zu B), output=%p (%zu B)",
rank, ctx->inputStaging->data(), inputSize, ctx->outputStaging->data(), outputSize);
TransportFlags allTransports = Transport::CudaIpc;
RegisteredMemory inputBufRegMem = comm->registerMemory(
ctx->inputStaging->data(), ctx->inputStaging->bytes(), allTransports);
RegisteredMemory outputBufRegMem = comm->registerMemory(
ctx->outputStaging->data(), ctx->outputStaging->bytes(), allTransports);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: registered input=%p, output=%p",
rank, inputBufRegMem.data(), outputBufRegMem.data());
std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: remoteOutput[%zu] data=%p, size=%zu",
rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
if (remoteOutputMemories[i].data() == nullptr) {
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] ERROR: remoteOutput[%zu] has NULL data pointer! "
"Cross-node CudaIpc mapping failed.", rank, i);
}
}
constexpr int nChannelsPerConnection = 1;
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu semaphores created",
rank, ctx->memorySemaphores.size());
ctx->memoryChannels = setupMemoryChannels(
this->conns_, ctx->memorySemaphores, remoteOutputMemories, inputBufRegMem, nChannelsPerConnection);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu memoryChannels created",
rank, ctx->memoryChannels.size());
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
ctx->registeredMemories = std::move(remoteOutputMemories);
@@ -268,12 +300,20 @@ std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);
std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: input=%p (%zu B), output=%p (%zu B), remotes=%zu",
rank, input, inputSize, output, outputSize, remoteOutputMemories.size());
for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: remoteOutput[%zu] data=%p, size=%zu",
rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
}
ctx->proxyService = std::make_shared<ProxyService>();
ctx->portChannels = setupAllPortChannels(
ctx->proxyService, *comm, this->conns_, remoteOutputMemories, inputBufRegMem);
ctx->portChannelDeviceHandles = setupPortChannelDeviceHandles(ctx->portChannels);
ctx->proxyService->startProxy(true);
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: %zu portChannels created, proxy started",
rank, ctx->portChannels.size());
ctx->registeredMemories = std::move(remoteOutputMemories);
ctx->registeredMemories.push_back(inputBufRegMem);