Add SwitchGroupSemaphore for O(1) multicast signal/wait

Implement signal/wait synchronization for the switch (multicast) group
using multimem.red hardware reduction on the NVSwitch, matching the
approach used by NCCL.

- Signal: multimem.red.release.sys.global.add.u32 on the multicast flag
  address atomically increments the flag on all peers via a single
  multicast reduction operation.
- Wait: polls the local device flag pointer until all devices have
  signaled (flag reaches numDevices * signalCount).

New types:
- SwitchGroupSemaphoreDeviceHandle: device-side handle with signal(),
  relaxedSignal(), wait(), and relaxedWait() methods.
- SwitchGroupSemaphore: host-side class that manages the flag channel
  and expected inbound counter.

Also adds a GroupSignalWait test that verifies all-rank GPU-side
barrier synchronization using the new semaphore.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-07 18:31:08 +00:00
parent fa95e82e18
commit 982b7ae230
4 changed files with 199 additions and 0 deletions

View File

@@ -25,6 +25,7 @@ struct SwitchChannel {
void* getDevicePtr();
friend class NvlsConnection;
friend struct SwitchGroupSemaphore;
};
class NvlsConnection {
@@ -60,6 +61,44 @@ class Communicator;
std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm, std::vector<int> allRanks,
size_t bufferSize);
/// A semaphore for O(1) signal/wait synchronization across all devices in a switch (multicast) group.
///
/// Uses `multimem.red` hardware reduction on the NVSwitch multicast address to signal all peers
/// simultaneously with a single operation, instead of O(N) point-to-point signals. Each device
/// atomically increments a shared flag via multicast reduction on signal, and polls its local
/// copy of the flag on wait.
///
/// The flag channel must be a @ref SwitchChannel bound to a buffer used exclusively for the flag.
/// The caller is responsible for keeping the flag channel (and its underlying @ref NvlsConnection
/// and @ref GpuBuffer) alive for the lifetime of this semaphore.
///
/// Example usage:
/// @code
/// auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);
/// auto flagChannel = nvlsConn->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());
/// auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numDevices);
/// auto devHandle = semaphore.deviceHandle();
/// // In kernel: devHandle.signal(); devHandle.wait();
/// @endcode
struct SwitchGroupSemaphore {
  /// Device-side handle type usable from GPU kernels.
  using DeviceHandle = SwitchGroupSemaphoreDeviceHandle;

  /// Construct a SwitchGroupSemaphore from a SwitchChannel used as the flag channel.
  /// @param flagChannel A SwitchChannel bound to a buffer used for the flag.
  /// @param numDevices The number of devices in the multicast group.
  SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices);

  /// Returns the device-side handle.
  /// @return The device-side handle for use in GPU kernels.
  DeviceHandle deviceHandle() const;

 private:
  void* mcFlag_;      ///< Multicast mapping of the flag, used for signaling via multimem.red.
  void* deviceFlag_;  ///< Local device mapping of the same flag, polled during wait.
  detail::UniqueGpuPtr<uint32_t> expectedInbound_;  ///< GPU-resident counter tracking the expected flag value.
  int numDevices_;    ///< Number of devices in the multicast group.
};
} // namespace mscclpp
#endif // MSCCLPP_SWITCH_CHANNEL_HPP_

View File

@@ -15,6 +15,11 @@
#include "device.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include "atomic_device.hpp"
#include "poll_device.hpp"
#endif // defined(MSCCLPP_DEVICE_COMPILE)
namespace mscclpp {
template <class>
@@ -200,6 +205,78 @@ struct SwitchChannelDeviceHandle {
#endif // defined(MSCCLPP_DEVICE_CUDA)
};
/// Device-side handle for @ref SwitchGroupSemaphore.
///
/// Provides O(1) signal/wait synchronization across all devices in a multicast group
/// using `multimem.red` hardware reduction on the NVSwitch. Each signal atomically
/// increments a flag on all peers via a single multicast reduction, and each wait
/// polls the local flag until all devices have signaled.
struct SwitchGroupSemaphoreDeviceHandle {
#if defined(MSCCLPP_DEVICE_CUDA)
  /// Signal all devices in the multicast group. Ensures prior memory operations are visible.
  MSCCLPP_DEVICE_INLINE void signal() {
    // One multimem.red reduction on the multicast address atomically adds 1 to the flag
    // on every device in the group; .release orders prior writes before the signal.
    asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
  }

  /// Relaxed signal; no prior memory completion guarantee. Use only for synchronizing execution, not data.
  MSCCLPP_DEVICE_INLINE void relaxedSignal() {
    asm volatile("multimem.red.relaxed.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
  }

  /// Wait for all devices in the group to signal.
  /// Each call advances the expected flag value by numDevices (one signal per device),
  /// then polls the local flag with acquire ordering until it reaches that value.
  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
  MSCCLPP_DEVICE_INLINE void wait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
    uint32_t expected = incExpectedInbound();
    POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
  }

  /// Relaxed wait; no memory completion guarantee. Use only for synchronizing execution, not data.
  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
  MSCCLPP_DEVICE_INLINE void relaxedWait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
    uint32_t expected = incExpectedInbound();
    POLL_MAYBE_JAILBREAK((loadInboundRelaxed() < expected), maxSpinCount);
  }

  /// Thread-safe read of expected inbound value.
  /// @return The expected inbound value.
  MSCCLPP_DEVICE_INLINE uint32_t loadExpectedInbound() {
    return atomicLoad<uint32_t, scopeDevice>(expectedInbound, memoryOrderRelaxed);
  }

  /// Thread-safe increment of expected inbound value by @ref numDevices.
  /// @return The incremented expected inbound value.
  MSCCLPP_DEVICE_INLINE uint32_t incExpectedInbound() {
    // atomicFetchAdd returns the pre-increment value, so add numDevices back to get
    // this caller's post-increment target; device scope suffices since the counter is local.
    return atomicFetchAdd<uint32_t, scopeDevice>(expectedInbound, static_cast<uint32_t>(numDevices),
                                                 memoryOrderRelaxed) +
           static_cast<uint32_t>(numDevices);
  }

  /// Thread-safe read of inbound flag value with acquire ordering.
  /// @return The inbound flag value.
  MSCCLPP_DEVICE_INLINE uint32_t loadInbound() {
    return atomicLoad<uint32_t, scopeSystem>(deviceFlag, memoryOrderAcquire);
  }

  /// Thread-safe read of inbound flag value with relaxed ordering.
  /// @return The inbound flag value.
  MSCCLPP_DEVICE_INLINE uint32_t loadInboundRelaxed() {
    return atomicLoad<uint32_t, scopeSystem>(deviceFlag, memoryOrderRelaxed);
  }
#endif  // defined(MSCCLPP_DEVICE_CUDA)

  /// Multicast address for the flag (used for signaling via multimem.red).
  uint32_t* mcFlag;
  /// Local device address for the flag (used for polling during wait).
  uint32_t* deviceFlag;
  /// Local GPU memory where the expected inbound value is tracked.
  uint32_t* expectedInbound;
  /// Number of devices in the multicast group.
  int numDevices;
};
} // namespace mscclpp
#endif // MSCCLPP_SWITCH_CHANNEL_DEVICE_HPP_

View File

@@ -233,6 +233,21 @@ SwitchChannel::DeviceHandle SwitchChannel::deviceHandle() const {
void* SwitchChannel::getDevicePtr() { return devicePtr_; };
// Captures raw pointers into the flag channel's multicast and local mappings; the caller
// must keep the flag channel (and its underlying NvlsConnection/GpuBuffer) alive for the
// lifetime of this semaphore.
SwitchGroupSemaphore::SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices)
    : mcFlag_(flagChannel.mcPtr_.get()),                      // multicast mapping of the flag
      deviceFlag_(flagChannel.devicePtr_),                    // local (unicast) mapping of the flag
      expectedInbound_(detail::gpuCallocUnique<uint32_t>()),  // zero-initialized counter on the device
      numDevices_(numDevices) {}
/// Builds the POD device-side handle from the host-side state.
/// @return Handle whose pointers reference the multicast flag, the local flag mapping,
///         and the expected-inbound counter managed by this semaphore.
SwitchGroupSemaphore::DeviceHandle SwitchGroupSemaphore::deviceHandle() const {
  return DeviceHandle{
      static_cast<uint32_t*>(mcFlag_),      // mcFlag
      static_cast<uint32_t*>(deviceFlag_),  // deviceFlag
      expectedInbound_.get(),               // expectedInbound
      numDevices_,                          // numDevices
  };
}
MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm,
std::vector<int> allRanks, size_t bufferSize) {
auto bootstrap = comm->bootstrap();

View File

@@ -25,6 +25,7 @@ void SwitchChannelTest::TearDown() { CommunicatorTestBase::TearDown(); }
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan;
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan1;
__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan2;
__constant__ mscclpp::SwitchGroupSemaphoreDeviceHandle gConstSwitchGroupSema;
__global__ void kernelSwitchReduce() {
#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
@@ -134,3 +135,70 @@ TEST(SwitchChannelTest, TwoChannelsSameConnection) {
ASSERT_EQ(result1, expected1);
ASSERT_EQ(result2, expected2);
}
// GPU-side barrier test kernel: all ranks signal and wait (barrier), then reduce and
// broadcast through the multicast data channel, then barrier again so the host on every
// rank can safely read the broadcast result.
__global__ void kernelSwitchGroupSignalWait() {
#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
  // All GPUs signal and wait - acts as a barrier ensuring all data is ready
  gConstSwitchGroupSema.signal();
  gConstSwitchGroupSema.wait();
  // After barrier, reduce and broadcast (safe because all data is visible)
  auto val = gConstSwitchChan.reduce<mscclpp::f32x1>(0);
  gConstSwitchChan.broadcast(0, val);
  // Signal and wait again to ensure broadcast is complete before host reads
  gConstSwitchGroupSema.signal();
  gConstSwitchGroupSema.wait();
#endif  // (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
}
// Verifies that SwitchGroupSemaphore acts as an all-rank GPU-side barrier: the kernel
// signals/waits before reducing and again before the host reads the result back.
// NOTE(review): `communicator` looks like a fixture member of SwitchChannelTest — confirm
// whether this should be TEST_F rather than TEST, matching the fixture's SetUp/TearDown.
TEST(SwitchChannelTest, GroupSignalWait) {
  if (gEnv->rank >= numRanksToUse) return;

  // Every participating rank joins the multicast group.
  std::vector<int> groupRanks(numRanksToUse);
  for (int r = 0; r < numRanksToUse; ++r) groupRanks[r] = r;

  // Per-rank contribution: rank 0 writes 1.0, rank 1 writes 2.0, ...
  auto dataBuffer = mscclpp::GpuBuffer<float>(256);
  const float contribution = static_cast<float>(gEnv->rank) + 1.0f;
  MSCCLPP_CUDATHROW(cudaMemcpy(dataBuffer.data(), &contribution, sizeof(contribution), cudaMemcpyHostToDevice));

  // A single NVLS connection backs both the data channel and the flag channel.
  auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);
  const size_t totalBytes = dataBuffer.bytes() + flagBuffer.bytes();
  auto nvlsConnection = mscclpp::connectNvlsCollective(communicator, groupRanks, totalBytes);
  auto dataChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(dataBuffer.data()), dataBuffer.bytes());
  auto flagChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());
  auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numRanksToUse);

  // Publish the device handles to the kernel's __constant__ symbols.
  auto chanHandle = dataChannel.deviceHandle();
  auto semaHandle = semaphore.deviceHandle();
  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchChan, &chanHandle, sizeof(chanHandle)));
  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchGroupSema, &semaHandle, sizeof(semaHandle)));
  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
  communicator->bootstrap()->barrier();

  // All ranks launch the kernel - signal/wait replaces the host barrier for GPU synchronization
  kernelSwitchGroupSignalWait<<<1, 1>>>();
  MSCCLPP_CUDATHROW(cudaGetLastError());
  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
  communicator->bootstrap()->barrier();

  float observed;
  MSCCLPP_CUDATHROW(cudaMemcpy(&observed, dataBuffer.data(), sizeof(observed), cudaMemcpyDeviceToHost));

  // Expected: sum of (rank + 1) over every participating rank.
  float expected = 0.0f;
  for (int r = 0; r < numRanksToUse; ++r) expected += static_cast<float>(r) + 1.0f;

  if (observed != expected) {
    std::cerr << "Expected " << expected << " but got " << observed << " for rank " << gEnv->rank << std::endl;
  }
  ASSERT_EQ(observed, expected);
}