mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
update
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mscclpp/env.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include <mscclpp/version.hpp>
|
||||
@@ -430,7 +431,7 @@ struct EndpointConfig {
|
||||
int maxWrPerSend = DefaultMaxWrPerSend, Mode mode = Mode::Default)
|
||||
: deviceIndex(deviceIndex),
|
||||
port(port),
|
||||
gidIndex(gidIndex),
|
||||
gidIndex(env()->ibGidIndex > 0 ? env()->ibGidIndex : gidIndex),
|
||||
maxCqSize(maxCqSize),
|
||||
maxCqPollNum(maxCqPollNum),
|
||||
maxSendWr(maxSendWr),
|
||||
|
||||
@@ -109,7 +109,7 @@ namespace mscclpp {
|
||||
|
||||
struct ExecutionContext {
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::unordered_map<int, Connection> connections;
|
||||
std::vector<Connection> connections;
|
||||
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
|
||||
MemoryId localMemoryIdBegin = MemoryId(0);
|
||||
|
||||
@@ -121,8 +121,6 @@ struct ExecutionContext {
|
||||
// local registered memories to keep resources alive
|
||||
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
std::vector<mscclpp::SemaphoreId> proxySemaphores;
|
||||
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
|
||||
std::vector<mscclpp::BasePortChannel> portChannels;
|
||||
std::vector<mscclpp::SwitchChannel> nvlsChannels;
|
||||
@@ -266,12 +264,24 @@ struct Executor::Impl {
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
Transport transport =
|
||||
!useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
|
||||
connectionFutures.push_back(this->comm->connect(transport, peer));
|
||||
std::unordered_map<int, int> peerTags;
|
||||
Transport ibTransport = IBs[rank % this->nranksPerNode];
|
||||
std::vector<std::shared_future<Connection>> connFutures;
|
||||
for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) {
|
||||
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
|
||||
for (const auto& info : channelInfos) {
|
||||
for (int peer : info.connectedPeers) {
|
||||
Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc;
|
||||
connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++));
|
||||
}
|
||||
}
|
||||
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
|
||||
for (const auto& info : channelInfos) {
|
||||
for (int peer : info.connectedPeers) {
|
||||
Transport transport = useIB(rank, peer, this->nranksPerNode) ? ibTransport : Transport::CudaIpc;
|
||||
connFutures.push_back(this->comm->connect(transport, peer, peerTags[peer]++));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < connectionFutures.size(); i++) {
|
||||
context.connections[connectedPeers[i]] = connectionFutures[i].get();
|
||||
@@ -360,18 +370,15 @@ struct Executor::Impl {
|
||||
proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get()));
|
||||
}
|
||||
|
||||
context.memorySemaphores = std::move(memorySemaphores);
|
||||
context.proxySemaphores = std::move(proxySemaphores);
|
||||
|
||||
for (ChannelType channelType : channelTypes) {
|
||||
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
|
||||
int index = 0;
|
||||
for (ChannelInfo& info : channelInfos) {
|
||||
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
|
||||
if (channelType == ChannelType::MEMORY) {
|
||||
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
|
||||
context.memoryChannels.emplace_back(memorySemaphores[index++]);
|
||||
} else if (channelType == ChannelType::PORT) {
|
||||
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
|
||||
context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user