diff --git a/include/mscclpp/switch_channel.hpp b/include/mscclpp/switch_channel.hpp
index 2dba722a..0cc55ed3 100644
--- a/include/mscclpp/switch_channel.hpp
+++ b/include/mscclpp/switch_channel.hpp
@@ -25,6 +25,7 @@ struct SwitchChannel {
   void* getDevicePtr();
 
   friend class NvlsConnection;
+  friend struct SwitchGroupSemaphore;
 };
 
 class NvlsConnection {
@@ -60,6 +61,44 @@ class Communicator;
 
 std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm, std::vector<int> allRanks,
                                                       size_t bufferSize);
+/// A semaphore for O(1) signal/wait synchronization across all devices in a switch (multicast) group.
+///
+/// Uses `multimem.red` hardware reduction on the NVSwitch multicast address to signal all peers
+/// simultaneously with a single operation, instead of O(N) point-to-point signals. Each device
+/// atomically increments a shared flag via multicast reduction on signal, and polls its local
+/// copy of the flag on wait.
+///
+/// The flag channel must be a @ref SwitchChannel bound to a buffer used exclusively for the flag.
+/// The caller is responsible for keeping the flag channel (and its underlying @ref NvlsConnection
+/// and @ref GpuBuffer) alive for the lifetime of this semaphore.
+///
+/// Example usage:
+/// @code
+/// auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);
+/// auto flagChannel = nvlsConn->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());
+/// auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numDevices);
+/// auto devHandle = semaphore.deviceHandle();
+/// // In kernel: devHandle.signal(); devHandle.wait();
+/// @endcode
+struct SwitchGroupSemaphore {
+  using DeviceHandle = SwitchGroupSemaphoreDeviceHandle;
+
+  /// Construct a SwitchGroupSemaphore from a SwitchChannel used as the flag channel.
+  /// @param flagChannel A SwitchChannel bound to a buffer used for the flag.
+  /// @param numDevices The number of devices in the multicast group.
+  SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices);
+
+  /// Returns the device-side handle.
+  /// @return The device-side handle for use in GPU kernels.
+  DeviceHandle deviceHandle() const;
+
+ private:
+  void* mcFlag_;
+  void* deviceFlag_;
+  detail::UniqueGpuPtr<uint32_t> expectedInbound_;
+  int numDevices_;
+};
+
 }  // namespace mscclpp
 
 #endif  // MSCCLPP_SWITCH_CHANNEL_HPP_
diff --git a/include/mscclpp/switch_channel_device.hpp b/include/mscclpp/switch_channel_device.hpp
index b52b6572..413a3dda 100644
--- a/include/mscclpp/switch_channel_device.hpp
+++ b/include/mscclpp/switch_channel_device.hpp
@@ -15,6 +15,11 @@
 
 #include "device.hpp"
 
+#if defined(MSCCLPP_DEVICE_COMPILE)
+#include "atomic_device.hpp"
+#include "poll_device.hpp"
+#endif  // defined(MSCCLPP_DEVICE_COMPILE)
+
 namespace mscclpp {
 
 template <typename T>
@@ -200,6 +205,78 @@ struct SwitchChannelDeviceHandle {
 #endif  // defined(MSCCLPP_DEVICE_CUDA)
 };
 
+/// Device-side handle for @ref SwitchGroupSemaphore.
+///
+/// Provides O(1) signal/wait synchronization across all devices in a multicast group
+/// using `multimem.red` hardware reduction on the NVSwitch. Each signal atomically
+/// increments a flag on all peers via a single multicast reduction, and each wait
+/// polls the local flag until all devices have signaled.
+struct SwitchGroupSemaphoreDeviceHandle {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  /// Signal all devices in the multicast group.
+  /// Ensures prior memory operations are visible.
+  MSCCLPP_DEVICE_INLINE void signal() {
+    asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
+  }
+
+  /// Relaxed signal; no prior memory completion guarantee. Use only for synchronizing execution, not data.
+  MSCCLPP_DEVICE_INLINE void relaxedSignal() {
+    asm volatile("multimem.red.relaxed.sys.global.add.u32 [%0], %1;" ::"l"(mcFlag), "r"(1u) : "memory");
+  }
+
+  /// Wait for all devices in the group to signal.
+  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
+  MSCCLPP_DEVICE_INLINE void wait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
+    uint32_t expected = incExpectedInbound();
+    POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
+  }
+
+  /// Relaxed wait; no memory completion guarantee. Use only for synchronizing execution, not data.
+  /// @param maxSpinCount Maximum number of spin iterations before assertion. Never asserts if negative.
+  MSCCLPP_DEVICE_INLINE void relaxedWait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
+    uint32_t expected = incExpectedInbound();
+    POLL_MAYBE_JAILBREAK((loadInboundRelaxed() < expected), maxSpinCount);
+  }
+
+  /// Thread-safe read of expected inbound value.
+  /// @return The expected inbound value.
+  MSCCLPP_DEVICE_INLINE uint32_t loadExpectedInbound() {
+    return atomicLoad(expectedInbound, memoryOrderRelaxed);
+  }
+
+  /// Thread-safe increment of expected inbound value by @ref numDevices.
+  /// @return The incremented expected inbound value.
+  MSCCLPP_DEVICE_INLINE uint32_t incExpectedInbound() {
+    return atomicFetchAdd(expectedInbound, static_cast<uint32_t>(numDevices), memoryOrderRelaxed) +
+           static_cast<uint32_t>(numDevices);
+  }
+
+  /// Thread-safe read of inbound flag value with acquire ordering.
+  /// @return The inbound flag value.
+  MSCCLPP_DEVICE_INLINE uint32_t loadInbound() {
+    return atomicLoad(deviceFlag, memoryOrderAcquire);
+  }
+
+  /// Thread-safe read of inbound flag value with relaxed ordering.
+  /// @return The inbound flag value.
+  MSCCLPP_DEVICE_INLINE uint32_t loadInboundRelaxed() {
+    return atomicLoad(deviceFlag, memoryOrderRelaxed);
+  }
+#endif  // defined(MSCCLPP_DEVICE_CUDA)
+
+  /// Multicast address for the flag (used for signaling via multimem.red).
+  uint32_t* mcFlag;
+
+  /// Local device address for the flag (used for polling during wait).
+  uint32_t* deviceFlag;
+
+  /// Local GPU memory where the expected inbound value is tracked.
+  uint32_t* expectedInbound;
+
+  /// Number of devices in the multicast group.
+  int numDevices;
+};
+
 }  // namespace mscclpp
 
 #endif  // MSCCLPP_SWITCH_CHANNEL_DEVICE_HPP_
diff --git a/src/core/switch_channel.cc b/src/core/switch_channel.cc
index 981e3ca1..0d60e453 100644
--- a/src/core/switch_channel.cc
+++ b/src/core/switch_channel.cc
@@ -233,6 +233,21 @@ SwitchChannel::DeviceHandle SwitchChannel::deviceHandle() const {
 
 void* SwitchChannel::getDevicePtr() { return devicePtr_; };
 
+SwitchGroupSemaphore::SwitchGroupSemaphore(SwitchChannel& flagChannel, int numDevices)
+    : mcFlag_(flagChannel.mcPtr_.get()),
+      deviceFlag_(flagChannel.devicePtr_),
+      expectedInbound_(detail::gpuCallocUnique<uint32_t>()),
+      numDevices_(numDevices) {}
+
+SwitchGroupSemaphore::DeviceHandle SwitchGroupSemaphore::deviceHandle() const {
+  DeviceHandle handle;
+  handle.mcFlag = reinterpret_cast<uint32_t*>(mcFlag_);
+  handle.deviceFlag = reinterpret_cast<uint32_t*>(deviceFlag_);
+  handle.expectedInbound = expectedInbound_.get();
+  handle.numDevices = numDevices_;
+  return handle;
+}
+
 MSCCLPP_API_CPP std::shared_ptr<NvlsConnection> connectNvlsCollective(std::shared_ptr<Communicator> comm,
                                                                       std::vector<int> allRanks, size_t bufferSize) {
   auto bootstrap = comm->bootstrap();
diff --git a/test/mp_unit/switch_channel_tests.cu b/test/mp_unit/switch_channel_tests.cu
index 6d913c64..3e97c40e 100644
--- a/test/mp_unit/switch_channel_tests.cu
+++ b/test/mp_unit/switch_channel_tests.cu
@@ -25,6 +25,7 @@ void SwitchChannelTest::TearDown() { CommunicatorTestBase::TearDown(); }
 __constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan;
 __constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan1;
 __constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan2;
+__constant__ mscclpp::SwitchGroupSemaphoreDeviceHandle gConstSwitchGroupSema;
 
 __global__ void kernelSwitchReduce() {
 #if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
@@ -134,3 +135,70 @@ TEST(SwitchChannelTest, TwoChannelsSameConnection) {
   ASSERT_EQ(result1, expected1);
   ASSERT_EQ(result2, expected2);
 }
+
+__global__ void kernelSwitchGroupSignalWait() {
+#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
+  // All GPUs signal and wait - acts as a barrier ensuring all data is ready
+  gConstSwitchGroupSema.signal();
+  gConstSwitchGroupSema.wait();
+
+  // After barrier, reduce and broadcast (safe because all data is visible)
+  auto val = gConstSwitchChan.reduce(0);
+  gConstSwitchChan.broadcast(0, val);
+
+  // Signal and wait again to ensure broadcast is complete before host reads
+  gConstSwitchGroupSema.signal();
+  gConstSwitchGroupSema.wait();
+#endif  // (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
+}
+
+TEST(SwitchChannelTest, GroupSignalWait) {
+  if (gEnv->rank >= numRanksToUse) return;
+
+  std::vector<int> ranks;
+  for (int i = 0; i < numRanksToUse; i++) {
+    ranks.push_back(i);
+  }
+
+  auto buffer = mscclpp::GpuBuffer<float>(256);
+  float data = gEnv->rank + 1.0f;
+  MSCCLPP_CUDATHROW(cudaMemcpy(buffer.data(), &data, sizeof(data), cudaMemcpyHostToDevice));
+
+  auto flagBuffer = mscclpp::GpuBuffer<uint32_t>(1);
+
+  size_t connSize = buffer.bytes() + flagBuffer.bytes();
+  auto nvlsConnection = mscclpp::connectNvlsCollective(communicator, ranks, connSize);
+
+  auto switchChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(buffer.data()), buffer.bytes());
+  auto flagChannel = nvlsConnection->bindAllocatedMemory(CUdeviceptr(flagBuffer.data()), flagBuffer.bytes());
+
+  auto semaphore = mscclpp::SwitchGroupSemaphore(flagChannel, numRanksToUse);
+
+  auto deviceHandle = switchChannel.deviceHandle();
+  auto semaHandle = semaphore.deviceHandle();
+
+  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchChan, &deviceHandle, sizeof(deviceHandle)));
+  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchGroupSema, &semaHandle, sizeof(semaHandle)));
+  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
+
+  communicator->bootstrap()->barrier();
+
+  // All ranks launch the kernel - signal/wait replaces the host barrier for GPU synchronization
+  kernelSwitchGroupSignalWait<<<1, 1>>>();
+  MSCCLPP_CUDATHROW(cudaGetLastError());
+  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
+
+  communicator->bootstrap()->barrier();
+
+  float result;
+  MSCCLPP_CUDATHROW(cudaMemcpy(&result, buffer.data(), sizeof(result), cudaMemcpyDeviceToHost));
+
+  float expected = 0.0f;
+  for (int i = 0; i < numRanksToUse; i++) {
+    expected += i + 1.0f;
+  }
+  if (result != expected) {
+    std::cerr << "Expected " << expected << " but got " << result << " for rank " << gEnv->rank << std::endl;
+  }
+  ASSERT_EQ(result, expected);
+}