mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
WIP
This commit is contained in:
@@ -102,7 +102,6 @@ void NativeAlgorithm::reset(int32_t contextKey) {
|
||||
} else {
|
||||
contexts_.clear();
|
||||
customKeyMap_.clear();
|
||||
initialized_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -78,6 +78,8 @@ struct AllreduceNvlsPacketAdapter {
|
||||
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
|
||||
int nSwitchChannels = 1;
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
|
||||
this->switchChannels_ =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
|
||||
@@ -92,9 +94,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
// setup channels
|
||||
int nSwitchChannels = 1;
|
||||
ctx->switchChannels =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
ctx->switchChannels = this->switchChannels_;
|
||||
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
|
||||
uintptr_t flagBuffer_;
|
||||
size_t flagBufferSize_;
|
||||
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
|
||||
std::vector<SwitchChannel> switchChannels_;
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
|
||||
Reference in New Issue
Block a user