mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-07-03 05:47:02 +00:00
Merge branch 'main' into rjsouza/nvls-allgather-pr
This commit is contained in:
@@ -95,6 +95,7 @@ struct hash<mscclpp::DeviceExecutionPlanKey> {
|
||||
namespace {
|
||||
auto hasIBDevices = []() { return mscclpp::getIBDeviceCount() > 0; };
|
||||
|
||||
// TODO(binyli): Need to add NVL domain check.
|
||||
auto useIB = [](int rank1, int rank2, int nranksPerNode) {
|
||||
if (mscclpp::env()->forceDisableIb) return false;
|
||||
bool inSameNode = rank1 / nranksPerNode == rank2 / nranksPerNode;
|
||||
@@ -110,7 +111,7 @@ namespace mscclpp {
|
||||
|
||||
struct ExecutionContext {
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::unordered_map<int, Connection> connections;
|
||||
std::vector<Connection> connections;
|
||||
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
|
||||
MemoryId localMemoryIdBegin = MemoryId(0);
|
||||
|
||||
@@ -122,8 +123,6 @@ struct ExecutionContext {
|
||||
// local registered memories to keep resources alive
|
||||
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
std::vector<mscclpp::SemaphoreId> proxySemaphores;
|
||||
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
|
||||
std::vector<mscclpp::BasePortChannel> portChannels;
|
||||
std::vector<mscclpp::SwitchChannel> nvlsChannels;
|
||||
@@ -267,15 +266,36 @@ struct Executor::Impl {
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
Transport transport =
|
||||
!useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
|
||||
connectionFutures.push_back(this->comm->connect(transport, peer));
|
||||
// Create one connection (unique QP) per channel entry. Each channel gets its own
|
||||
// QP — no shared connections.
|
||||
// Use per-peer tag counters so that matched connections between pairs of ranks use
|
||||
// the same tag, regardless of the order peers appear in each rank's connected_to list.
|
||||
std::unordered_map<int, int> peerTagCounters;
|
||||
Transport ibTransport = IBs[rank % this->nranksPerNode];
|
||||
std::vector<std::shared_future<Connection>> connFutures;
|
||||
for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) {
|
||||
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
|
||||
for (const auto& info : channelInfos) {
|
||||
for (int peer : info.connectedPeers) {
|
||||
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
|
||||
? ibTransport
|
||||
: Transport::CudaIpc;
|
||||
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
|
||||
}
|
||||
}
|
||||
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
|
||||
for (const auto& info : channelInfos) {
|
||||
for (int peer : info.connectedPeers) {
|
||||
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
|
||||
? ibTransport
|
||||
: Transport::CudaIpc;
|
||||
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < connectionFutures.size(); i++) {
|
||||
context.connections[connectedPeers[i]] = connectionFutures[i].get();
|
||||
|
||||
for (auto& future : connFutures) {
|
||||
context.connections.push_back(future.get());
|
||||
}
|
||||
|
||||
std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
|
||||
@@ -329,10 +349,11 @@ struct Executor::Impl {
|
||||
std::vector<std::shared_future<Semaphore>> futureProxySemaphores;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
std::vector<mscclpp::SemaphoreId> proxySemaphores;
|
||||
int connIdx = 0;
|
||||
auto processChannelInfos = [&](std::vector<ChannelInfo>& channelInfos) {
|
||||
for (ChannelInfo& info : channelInfos) {
|
||||
for (int peer : info.connectedPeers) {
|
||||
auto connection = context.connections.at(peer);
|
||||
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
|
||||
auto& connection = context.connections[connIdx++];
|
||||
if (info.channelType == ChannelType::MEMORY) {
|
||||
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
|
||||
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
|
||||
@@ -361,18 +382,15 @@ struct Executor::Impl {
|
||||
proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get()));
|
||||
}
|
||||
|
||||
context.memorySemaphores = std::move(memorySemaphores);
|
||||
context.proxySemaphores = std::move(proxySemaphores);
|
||||
|
||||
for (ChannelType channelType : channelTypes) {
|
||||
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
|
||||
int index = 0;
|
||||
for (ChannelInfo& info : channelInfos) {
|
||||
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
|
||||
if (channelType == ChannelType::MEMORY) {
|
||||
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
|
||||
context.memoryChannels.emplace_back(memorySemaphores[index++]);
|
||||
} else if (channelType == ChannelType::PORT) {
|
||||
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
|
||||
context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,11 +174,11 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out
|
||||
uint32_t dstOffset =
|
||||
dstOffsets[tid] + getOffset<ReuseScratch>(portChannelBufferTypes_[op.outputBufferRefs[tid].id], offset);
|
||||
uint32_t srcOffset = srcOffsets[tid] + getOffset<ReuseScratch>(op.inputBufferRefs[tid].type, offset);
|
||||
if constexpr (PutWithSignal) {
|
||||
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
|
||||
} else if constexpr (PutWithSignalAndFlush) {
|
||||
if constexpr (PutWithSignalAndFlush) {
|
||||
portChannels_[channelIndexes[tid]].putWithSignalAndFlush(dstMemoryId, (uint64_t)dstOffset, srcMemoryId,
|
||||
(uint64_t)srcOffsets, size);
|
||||
(uint64_t)srcOffset, size);
|
||||
} else if constexpr (PutWithSignal) {
|
||||
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
|
||||
} else {
|
||||
portChannels_[channelIndexes[tid]].put(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user