mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Change persist MemoryChannel objects as class member to prevent dangling device pointers
This commit is contained in:
@@ -249,6 +249,11 @@ class AllToAllVTestEngine : public BaseTestEngine {
|
||||
std::shared_ptr<int> sendBuff_;
|
||||
std::shared_ptr<int> recvBuff_;
|
||||
std::shared_ptr<int[]> expectedBuff_;
|
||||
|
||||
// Must persist across setupConnections() → kernel execution so that
|
||||
// IPC memory mappings (RegisteredMemory) and semaphore tokens remain alive.
|
||||
// The device handles contain raw pointers into these mapped regions.
|
||||
std::vector<mscclpp::MemoryChannel> memoryChannels_;
|
||||
};
|
||||
|
||||
bool AllToAllVTestEngine::isInPlace() const { return false; }
|
||||
@@ -309,16 +314,18 @@ void AllToAllVTestEngine::setupConnections() {
|
||||
std::vector<mscclpp::RegisteredMemory> remoteRecvMems;
|
||||
for (auto& f : remoteMemoryFutures) remoteRecvMems.push_back(f.get());
|
||||
|
||||
// Create MemoryChannels: dst = peer's recv buf, src = our send buf
|
||||
std::vector<mscclpp::MemoryChannel> memoryChannels;
|
||||
// Create MemoryChannels: dst = peer's recv buf, src = our send buf.
|
||||
// Store in class member so IPC mappings + semaphore tokens stay alive
|
||||
// until the engine is destroyed (device handles hold raw pointers).
|
||||
memoryChannels_.clear();
|
||||
for (size_t i = 0; i < connections.size(); i++) {
|
||||
auto semaphore = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[i]);
|
||||
memoryChannels.emplace_back(semaphore, remoteRecvMems[i], sendBufRegMem);
|
||||
memoryChannels_.emplace_back(semaphore, remoteRecvMems[i], sendBufRegMem);
|
||||
}
|
||||
|
||||
// Convert to device handles and copy to device memory
|
||||
std::vector<DeviceHandle<mscclpp::MemoryChannel>> memoryChannelHandles;
|
||||
for (auto& channel : memoryChannels) {
|
||||
for (auto& channel : memoryChannels_) {
|
||||
memoryChannelHandles.push_back(mscclpp::deviceHandle(channel));
|
||||
}
|
||||
CUDATHROW(cudaMemcpy(d_memoryChannels, memoryChannelHandles.data(),
|
||||
|
||||
Reference in New Issue
Block a user