mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Add SwitchGroupSemaphore for O(1) multicast signal/wait
Implement signal/wait synchronization for the switch (multicast) group using multimem.red hardware reduction on the NVSwitch, matching the approach used by NCCL. - Signal: multimem.red.release.sys.global.add.u32 on the multicast flag address atomically increments the flag on all peers via a single multicast reduction operation. - Wait: polls the local device flag pointer until all devices have signaled (flag reaches numDevices * signalCount). New types: - SwitchGroupSemaphoreDeviceHandle: device-side handle with signal(), relaxedSignal(), wait(), and relaxedWait() methods. - SwitchGroupSemaphore: host-side class that manages the flag channel and expected inbound counter. Also adds a GroupSignalWait test that verifies all-rank GPU-side barrier synchronization using the new semaphore. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@ struct SwitchChannel {
|
||||
void* getDevicePtr();
|
||||
|
||||
friend class NvlsConnection;
|
||||
friend struct SwitchGroupSemaphore;
|
||||
};
|
||||
|
||||
class NvlsConnection {
|
||||
@@ -60,6 +61,44 @@ class Communicator;
|
||||
std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm, std::vector<int> allRanks,
|
||||
size_t bufferSize);
|
||||
|
||||
/// A semaphore for O(1) signal/wait synchronization across all devices in a switch (multicast) group.
///
/// Uses `multimem.red` hardware reduction on the NVSwitch multicast address to signal all peers
/// simultaneously with a single operation, instead of O(N) point-to-point signals. Each device
/// atomically increments a shared flag via multicast reduction on signal, and polls its local
/// copy of the flag on wait.
///
/// The flag channel must be a @ref SwitchChannel bound to a buffer used exclusively for the flag.
/// The caller is responsible for keeping the flag channel (and its underlying @ref NvlsConnection
/// and @ref GpuBuffer) alive for the lifetime of this semaphore.
///
/// Example usage:
/// @code
/// auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);
/// auto flagChannel = nvlsConn->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());
/// auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numDevices);
/// auto devHandle = semaphore.deviceHandle();
/// // In kernel: devHandle.signal(); devHandle.wait();
/// @endcode
struct SwitchGroupSemaphore {
  using DeviceHandle = SwitchGroupSemaphoreDeviceHandle;

  /// Construct a SwitchGroupSemaphore from a SwitchChannel used as the flag channel.
  /// @param flagChannel A SwitchChannel bound to a buffer used for the flag.
  /// @param numDevices The number of devices in the multicast group.
  SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices);

  /// Returns the device-side handle.
  /// @return The device-side handle for use in GPU kernels.
  DeviceHandle deviceHandle() const;

 private:
  /// Multicast address of the flag; signal() targets this via multimem.red.
  void* mcFlag_;
  /// This device's local mapping of the flag; wait() polls it.
  void* deviceFlag_;
  /// Device-resident counter tracking the expected inbound flag value (owned by the semaphore).
  detail::UniqueGpuPtr<uint32_t> expectedInbound_;
  /// Number of devices participating in the multicast group.
  int numDevices_;
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_SWITCH_CHANNEL_HPP_
|
||||
|
||||
@@ -15,6 +15,11 @@
|
||||
|
||||
#include "device.hpp"
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
#include "atomic_device.hpp"
|
||||
#include "poll_device.hpp"
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
template <class>
|
||||
@@ -200,6 +205,78 @@ struct SwitchChannelDeviceHandle {
|
||||
#endif // defined(MSCCLPP_DEVICE_CUDA)
|
||||
};
|
||||
|
||||
/// Device-side handle for @ref SwitchGroupSemaphore.
///
/// Provides O(1) signal/wait synchronization across all devices in a multicast group
/// using `multimem.red` hardware reduction on the NVSwitch. Each signal atomically
/// increments a flag on all peers via a single multicast reduction, and each wait
/// polls the local flag until all devices have signaled.
///
/// NOTE(review): the multimem.* PTX instructions presumably require NVLS-capable
/// hardware (Hopper/sm_90 or newer) — callers appear to guard launches with
/// `__CUDA_ARCH__ >= 900`; confirm before using on older architectures.
struct SwitchGroupSemaphoreDeviceHandle {
#if defined(MSCCLPP_DEVICE_CUDA)
  /// Signal all devices in the multicast group. Ensures prior memory operations are visible.
  /// Release ordering publishes writes issued before this call to any peer that subsequently
  /// observes the incremented flag through an acquiring wait().
  MSCCLPP_DEVICE_INLINE void signal() {
    // A single multicast reduction increments the flag on every device in the group.
    asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
  }

  /// Relaxed signal; no prior memory completion guarantee. Use only for synchronizing execution, not data.
  MSCCLPP_DEVICE_INLINE void relaxedSignal() {
    asm volatile("multimem.red.relaxed.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
  }

  /// Wait for all devices in the group to signal.
  /// Advances the expected target by numDevices, then polls the local flag (acquire) until it
  /// reaches the target, i.e. until every device's multicast increment has arrived.
  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
  MSCCLPP_DEVICE_INLINE void wait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
    uint32_t expected = incExpectedInbound();
    POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
  }

  /// Relaxed wait; no memory completion guarantee. Use only for synchronizing execution, not data.
  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
  MSCCLPP_DEVICE_INLINE void relaxedWait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
    uint32_t expected = incExpectedInbound();
    POLL_MAYBE_JAILBREAK((loadInboundRelaxed() < expected), maxSpinCount);
  }

  /// Thread-safe read of expected inbound value.
  /// @return The expected inbound value.
  MSCCLPP_DEVICE_INLINE uint32_t loadExpectedInbound() {
    return atomicLoad<uint32_t, scopeDevice>(expectedInbound, memoryOrderRelaxed);
  }

  /// Thread-safe increment of expected inbound value by @ref numDevices.
  /// @return The incremented expected inbound value.
  MSCCLPP_DEVICE_INLINE uint32_t incExpectedInbound() {
    // atomicFetchAdd returns the pre-increment value, so add numDevices again
    // to produce the post-increment target the caller should wait for.
    return atomicFetchAdd<uint32_t, scopeDevice>(expectedInbound, static_cast<uint32_t>(numDevices),
                                                 memoryOrderRelaxed) +
           static_cast<uint32_t>(numDevices);
  }

  /// Thread-safe read of inbound flag value with acquire ordering.
  /// @return The inbound flag value.
  MSCCLPP_DEVICE_INLINE uint32_t loadInbound() {
    return atomicLoad<uint32_t, scopeSystem>(deviceFlag, memoryOrderAcquire);
  }

  /// Thread-safe read of inbound flag value with relaxed ordering.
  /// @return The inbound flag value.
  MSCCLPP_DEVICE_INLINE uint32_t loadInboundRelaxed() {
    return atomicLoad<uint32_t, scopeSystem>(deviceFlag, memoryOrderRelaxed);
  }
#endif  // defined(MSCCLPP_DEVICE_CUDA)

  /// Multicast address for the flag (used for signaling via multimem.red).
  uint32_t* mcFlag;

  /// Local device address for the flag (used for polling during wait).
  uint32_t* deviceFlag;

  /// Local GPU memory where the expected inbound value is tracked.
  uint32_t* expectedInbound;

  /// Number of devices in the multicast group.
  int numDevices;
};
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_SWITCH_CHANNEL_DEVICE_HPP_
|
||||
|
||||
@@ -233,6 +233,21 @@ SwitchChannel::DeviceHandle SwitchChannel::deviceHandle() const {
|
||||
|
||||
/// Returns the raw local device pointer backing this channel's bound memory.
void* SwitchChannel::getDevicePtr() { return devicePtr_; }
|
||||
|
||||
// Cache the flag channel's multicast and local device pointers, and allocate a
// device counter for the expected inbound value (gpuCallocUnique — presumably
// zero-initialized, calloc-style; confirm against the allocator).
// NOTE(review): the semaphore does not own the flag channel's memory; the caller
// must keep the channel (and its NvlsConnection/GpuBuffer) alive for the
// semaphore's lifetime.
SwitchGroupSemaphore::SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices)
    : mcFlag_(flagChannel.mcPtr_.get()),
      deviceFlag_(flagChannel.devicePtr_),
      expectedInbound_(detail::gpuCallocUnique<uint32_t>()),
      numDevices_(numDevices) {}
|
||||
|
||||
/// Builds the POD device-side handle from the host-side state.
/// @return A handle usable from GPU kernels (fields mirror this semaphore's state).
SwitchGroupSemaphore::DeviceHandle SwitchGroupSemaphore::deviceHandle() const {
  // Aggregate-initialize in field declaration order.
  return DeviceHandle{
      reinterpret_cast<uint32_t*>(mcFlag_),      // multicast flag address (signal target)
      reinterpret_cast<uint32_t*>(deviceFlag_),  // local flag address (wait polls this)
      expectedInbound_.get(),                    // expected-inbound counter in device memory
      numDevices_,                               // multicast group size
  };
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm,
|
||||
std::vector<int> allRanks, size_t bufferSize) {
|
||||
auto bootstrap = comm->bootstrap();
|
||||
|
||||
@@ -25,6 +25,7 @@ void SwitchChannelTest::TearDown() { CommunicatorTestBase::TearDown(); }
|
||||
// Device __constant__ handles; the host fills these with cudaMemcpyToSymbol before launching
// the test kernels, which read them directly.
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan;
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan1;
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan2;
__constant__ mscclpp::SwitchGroupSemaphoreDeviceHandle gConstSwitchGroupSema;
|
||||
|
||||
__global__ void kernelSwitchReduce() {
|
||||
#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
|
||||
@@ -134,3 +135,70 @@ TEST(SwitchChannelTest, TwoChannelsSameConnection) {
|
||||
ASSERT_EQ(result1, expected1);
|
||||
ASSERT_EQ(result2, expected2);
|
||||
}
|
||||
|
||||
// Exercises SwitchGroupSemaphore as a GPU-side all-device barrier: a signal/wait
// pair brackets each phase so every rank's data is visible before the multicast
// reduction, and the broadcast is complete before the host reads results.
__global__ void kernelSwitchGroupSignalWait() {
#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
  // All GPUs signal and wait - acts as a barrier ensuring all data is ready
  gConstSwitchGroupSema.signal();
  gConstSwitchGroupSema.wait();

  // After barrier, reduce and broadcast (safe because all data is visible)
  auto val = gConstSwitchChan.reduce<mscclpp::f32x1>(0);
  gConstSwitchChan.broadcast(0, val);

  // Signal and wait again to ensure broadcast is complete before host reads
  gConstSwitchGroupSema.signal();
  gConstSwitchGroupSema.wait();
#endif  // (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
}
|
||||
|
||||
// Verifies that SwitchGroupSemaphore acts as an all-rank GPU-side barrier: the kernel
// relies solely on signal()/wait() for inter-device ordering, and the multicast
// reduce+broadcast must still observe every rank's contribution.
TEST(SwitchChannelTest, GroupSignalWait) {
  if (gEnv->rank >= numRanksToUse) return;

  // All participating ranks join the same multicast group.
  std::vector<int> ranks;
  for (int i = 0; i < numRanksToUse; i++) {
    ranks.push_back(i);
  }

  // Each rank contributes a distinct value (rank + 1) in element 0 of its buffer.
  auto buffer = mscclpp::GpuBuffer<float>(256);
  float data = gEnv->rank + 1.0f;
  MSCCLPP_CUDATHROW(cudaMemcpy(buffer.data(), &data, sizeof(data), cudaMemcpyHostToDevice));

  // A separate single-uint32 buffer dedicated to the semaphore flag.
  auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);

  // The NVLS connection must be sized to cover both the data and the flag bindings.
  size_t connSize = buffer.bytes() + flagBuffer.bytes();
  auto nvlsConnection = mscclpp::connectNvlsCollective(communicator, ranks, connSize);

  auto switchChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(buffer.data()), buffer.bytes());
  auto flagChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());

  auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numRanksToUse);

  auto deviceHandle = switchChannel.deviceHandle();
  auto semaHandle = semaphore.deviceHandle();

  // Publish both handles to the kernel's __constant__ symbols before launch.
  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchChan, &deviceHandle, sizeof(deviceHandle)));
  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchGroupSema, &semaHandle, sizeof(semaHandle)));
  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());

  // Host barrier only guarantees setup is done on all ranks before any launch.
  communicator->bootstrap()->barrier();

  // All ranks launch the kernel - signal/wait replaces the host barrier for GPU synchronization
  kernelSwitchGroupSignalWait<<<1, 1>>>();
  MSCCLPP_CUDATHROW(cudaGetLastError());
  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());

  communicator->bootstrap()->barrier();

  float result;
  MSCCLPP_CUDATHROW(cudaMemcpy(&result, buffer.data(), sizeof(result), cudaMemcpyDeviceToHost));

  // Expected sum of all contributions: 1 + 2 + ... + numRanksToUse.
  // Exact float comparison is safe here: small integer sums are exact in fp32.
  float expected = 0.0f;
  for (int i = 0; i < numRanksToUse; i++) {
    expected += i + 1.0f;
  }
  if (result != expected) {
    std::cerr << "Expected " << expected << " but got " << result << " for rank " << gEnv->rank << std::endl;
  }
  ASSERT_EQ(result, expected);
}
|
||||
|
||||
Reference in New Issue
Block a user