Determine nRanksPerNode and localRank using hostname matching in mscclpp-test

This commit is contained in:
Qinghua Zhou
2026-03-05 13:33:24 +00:00
parent 82ed577a09
commit 872bc433a9
2 changed files with 25 additions and 49 deletions

View File

@@ -278,47 +278,10 @@ void AllToAllVTestEngine::allocateBuffer() {
}
void AllToAllVTestEngine::setupConnections() {
const int worldSize = args_.totalRanks;
const int rank = args_.rank;
// Register memory with CudaIpc transport for all peers.
// On NVLink-connected multi-node systems (e.g., GB200 NVL), CudaIpc works
// across nodes via NVLink. We force CudaIpc for all peers to avoid the
// default setupMeshConnections skipping non-CudaIpc connections when
// building MemoryChannels.
mscclpp::RegisteredMemory sendBufRegMem =
comm_->registerMemory(sendBuff_.get(), args_.maxBytes, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory recvBufRegMem =
comm_->registerMemory(recvBuff_.get(), args_.maxBytes, mscclpp::Transport::CudaIpc);
// Exchange recv buffer registration with all peers (PUT semantic: we write to peer's recv buffer)
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int r = 0; r < worldSize; r++) {
if (r == rank) continue;
connectionFutures.push_back(comm_->connect(mscclpp::Transport::CudaIpc, r));
comm_->sendMemory(recvBufRegMem, r);
remoteRegMemories.push_back(comm_->recvMemory(r));
}
std::vector<mscclpp::Connection> connections;
for (auto& f : connectionFutures) {
connections.push_back(f.get());
}
// Create D2D semaphores and MemoryChannels for all peers
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
for (size_t cid = 0; cid < connections.size(); cid++) {
memorySemaphores.push_back(
std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[cid]));
}
std::vector<mscclpp::MemoryChannel> memoryChannels;
for (size_t cid = 0; cid < connections.size(); cid++) {
// dst = peer's recv buffer (where we write), src = our send buffer (where we read)
memoryChannels.emplace_back(memorySemaphores[cid], remoteRegMemories[cid].get(),
sendBufRegMem, recvBuff_.get());
}
// Setup MemoryChannels: we write to peer's recv buffer from our send buffer
setupMeshConnections(memoryChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes,
ChannelSemantic::PUT, 1);
// Convert to device handles and copy to device memory
std::vector<DeviceHandle<mscclpp::MemoryChannel>> memoryChannelHandles;
@@ -328,10 +291,6 @@ void AllToAllVTestEngine::setupConnections() {
CUDATHROW(cudaMemcpy(d_memoryChannels, memoryChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::MemoryChannel>) * memoryChannelHandles.size(),
cudaMemcpyHostToDevice));
// Keep registered memory references alive
inputMemories_.push_back(sendBufRegMem);
outputMemories_.push_back(recvBufRegMem);
}
std::vector<void*> AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }

View File

@@ -650,11 +650,28 @@ void run(int argc, char* argv[]) {
MPI_Init(&argc, &argv);
MPI_Comm_size(MPI_COMM_WORLD, &totalRanks);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm shmcomm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
MPI_Comm_size(shmcomm, &nRanksPerNode);
MPI_Comm_rank(shmcomm, &localRank);
MPI_Comm_free(&shmcomm);
// Determine nRanksPerNode and localRank using hostname matching.
// MPI_COMM_TYPE_SHARED is unreliable on GB200 NVL systems where the shared
// memory domain can span NVLink-connected nodes across physical boundaries,
// causing wrong nRanksPerNode (e.g., 8 instead of 4) and invalid GPU ordinals.
{
constexpr int MAX_HOSTNAME = 256;
char myHost[MAX_HOSTNAME] = {};
strncpy(myHost, hostname.c_str(), MAX_HOSTNAME - 1);
std::vector<char> allHosts(totalRanks * MAX_HOSTNAME, 0);
MPI_Allgather(myHost, MAX_HOSTNAME, MPI_CHAR, allHosts.data(), MAX_HOSTNAME, MPI_CHAR, MPI_COMM_WORLD);
nRanksPerNode = 0;
localRank = 0;
for (int r = 0; r < totalRanks; r++) {
std::string rHost(&allHosts[r * MAX_HOSTNAME]);
if (rHost == hostname) {
if (r < rank) localRank++;
nRanksPerNode++;
}
}
}
isMainProc = (rank == 0) ? 1 : 0;
std::stringstream ss;