FIFO improvements (#557)

* Revert `MSCCLPP_FIFO_USE_TAIL_REPLICA=1` back to the default.
* Optimize `FifoDeviceHandle`.
* Do not use `cudaHostAllocWriteCombined` that increases latency.
* Pin host memory for `Host2DeviceSemaphore::outboundSemaphore_`.
* Fix proxy NUMA binding issues.
* Prevent graph capture inside proxy threads.
* Now `CudaIpcConnection` skips stream sync when unnecessary.
* Now any type of connection needs to hold a shared pointer to the
context for memory safety.
* Now a context should be always managed by a shared pointer for memory
safety.
* Minor docs & interface improvements.
* Minor fix in `mscclpp-test` correctness test.
This commit is contained in:
Changho Hwang
2025-06-24 09:50:28 -07:00
committed by GitHub
parent 2796cfa5ba
commit b4dde38db8
28 changed files with 384 additions and 353 deletions

View File

@@ -127,8 +127,8 @@ class TcpBootstrap : public Bootstrap {
/// @return The unique ID stored in the TcpBootstrap.
UniqueId getUniqueId() const;
/// Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any methods;
/// it can be created by createUniqueId() or can be any arbitrary bit arrays provided by the user.
/// Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any method;
/// it can be created by createUniqueId() or can be any arbitrary bit array provided by the user.
/// @param uniqueId The unique ID to initialize the TcpBootstrap with.
/// @param timeoutSec The connection timeout in seconds.
void initialize(UniqueId uniqueId, int64_t timeoutSec = 30);
@@ -453,7 +453,7 @@ class Endpoint {
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();
/// Deserialize a Endpoint object from a vector of characters.
/// Deserialize an Endpoint object from a vector of characters.
///
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
@@ -473,8 +473,10 @@ class Connection {
public:
/// Constructor.
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};
Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
: context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
/// Destructor.
virtual ~Connection() = default;
/// Write data from a source RegisteredMemory to a destination RegisteredMemory.
@@ -487,7 +489,7 @@ class Connection {
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) = 0;
/// Update a 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
/// Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
///
/// @param dst The destination RegisteredMemory.
/// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory.
@@ -522,7 +524,9 @@ class Connection {
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
int maxWriteQueueSize;
std::shared_ptr<Context> context_;
int maxWriteQueueSize_;
};
/// Used to configure an endpoint.
@@ -567,19 +571,19 @@ struct EndpointConfig {
/// 1. The client creates an endpoint with createEndpoint() and sends it to the server.
/// 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
/// client, and creates a connection with connect().
/// 4. The client receives the server endpoint, creates a connection with connect() and sends a
/// 3. The client receives the server endpoint, creates a connection with connect() and sends a
/// RegisteredMemory to the server.
/// 5. The server receives the RegisteredMemory and writes to it using the previously created connection.
/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server can not
/// 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server cannot
/// write to the RegisteredMemory before the connection is established.
///
/// While some transports may have more relaxed implementation behavior, this should not be relied upon.
class Context {
class Context : public std::enable_shared_from_this<Context> {
public:
/// Create a context.
Context();
/// Create a new Context instance.
static std::shared_ptr<Context> create() { return std::shared_ptr<Context>(new Context()); }
/// Destroy the context.
/// Destructor.
~Context();
/// Register a region of GPU memory for use in this context.
@@ -606,6 +610,8 @@ class Context {
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);
private:
Context();
struct Impl;
std::unique_ptr<Impl> pimpl_;
@@ -620,7 +626,7 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// A class that sets up all registered memories and connections between processes.
///
/// A typical way to use this class:
/// 1. Call connect() to declare connections between the calling process with other processes.
/// 1. Call connect() to declare connections between the calling process and other processes.
/// 2. Call registerMemory() to register memory regions that will be used for communication.
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
/// other processes.
@@ -670,7 +676,7 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
/// communicator.sendMemory(memory1, 0, tag);
/// ```
/// In the wrong example, the connection information from rank 1 will be sent to `mem1` object on rank 0,
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
/// where the object type is RegisteredMemory, not Connection.
///
class Communicator {
@@ -762,7 +768,7 @@ class Communicator {
/// the first get() on the future.
/// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
/// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
/// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all
/// a buffer is spread through multiple pages and does not fully utilize all of them, IB's QP has to register for all
/// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
///
/// Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by
@@ -818,11 +824,11 @@ extern const TransportFlags AllIBTransports;
/// A constant TransportFlags object representing all transports.
extern const TransportFlags AllTransports;
/// A type which could be safely used in device side.
/// A type which could be safely used on the device side.
template <class T>
using DeviceHandle = typename T::DeviceHandle;
/// Retrieve the deviceHandle instance from host object.
/// Retrieve the deviceHandle instance from a host object.
template <typename T>
DeviceHandle<std::remove_reference_t<T>> deviceHandle(T&& t) {
return t.deviceHandle();

View File

@@ -93,7 +93,7 @@ class Env {
/// Env name: `MSCCLPP_FIFO_USE_TAIL_REPLICA`. If set to true, it will replicate the FIFO tail on the GPU memory,
/// which makes the GPU poll on the tail faster, but requires a periodic FIFO flush to update the replica on the GPU.
/// If set to false, the GPU will directly read the tail from the host memory, which is slower but does not require
/// periodic flushes. Default is false.
/// periodic flushes. Default is true.
const bool fifoUseTailReplica;
private:

View File

@@ -4,51 +4,46 @@
#ifndef MSCCLPP_FIFO_HPP_
#define MSCCLPP_FIFO_HPP_
#include <cstdint>
#include <functional>
#include <memory>
#include "fifo_device.hpp"
namespace mscclpp {
constexpr size_t DEFAULT_FIFO_SIZE = 128;
constexpr size_t DEFAULT_FIFO_SIZE = 512;
/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
/// Host-side proxy FIFO for device-produced work elements.
class Fifo {
public:
/// Constructs a new Fifo object.
/// @param size The number of entires in the FIFO.
/// Constructor.
/// @param size Number of entries (default: DEFAULT_FIFO_SIZE).
Fifo(int size = DEFAULT_FIFO_SIZE);
/// Destroys the Fifo object.
/// Destructor.
~Fifo();
/// Polls the FIFO for a trigger.
///
/// Returns ProxyTrigger which is the trigger at the head of fifo.
/// Poll and get the trigger at the head.
/// @return ProxyTrigger at the head of the FIFO.
ProxyTrigger poll();
/// Pops a trigger from the FIFO.
/// Remove the head trigger.
void pop();
/// Flushes the tail of the FIFO.
///
/// @param sync If true, waits for the flush to complete before returning.
void flushTail(bool sync = false);
/// Return the FIFO size.
/// @return The FIFO size.
/// Get FIFO size.
/// @return Number of entries in the FIFO.
int size() const;
/// Returns a FifoDeviceHandle object representing the device FIFO.
///
/// @return A FifoDeviceHandle object representing the device FIFO.
/// Get device-side FIFO handle.
/// @return FifoDeviceHandle for device access.
FifoDeviceHandle deviceHandle() const;
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
std::unique_ptr<Impl> pimpl_;
};
} // namespace mscclpp

View File

@@ -15,7 +15,11 @@
namespace mscclpp {
/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
#if defined(MSCCLPP_DEVICE_COMPILE)
MSCCLPP_DEVICE_INLINE uint64_t hostLoadRelaxed(uint64_t* ptr) { return atomicLoad(ptr, memoryOrderRelaxed); }
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// Pair of 64-bit unsigned integers used as a trigger for the proxy.
///
/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
/// ProxyTrigger elements and a single host proxy thread consumes these work elements.
@@ -45,68 +49,63 @@ struct alignas(16) ProxyTrigger {
struct FifoDeviceHandle {
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Push a trigger to the FIFO.
///
/// @param trigger The trigger to push.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
/// @return The new head of the FIFO.
/// @param trigger Trigger to push.
/// @param maxSpinCount Max spin count before assert. Never assert if negative.
/// @return Previous head of the FIFO where the trigger was pushed.
MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
uint64_t curFifoHead = atomicFetchAdd(this->head, (uint64_t)1, memoryOrderRelaxed);
uint64_t prevHead = atomicFetchAdd<uint64_t, scopeDevice>(head, 1, memoryOrderRelaxed);
// make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host
// side
trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
// Flip the last bit for safe polling; host will revert.
constexpr uint64_t flipMask = uint64_t{1} << uint64_t{63};
trigger.snd ^= flipMask;
// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
// write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
// for the second condition we need to read CPU memory.
// As atomic access is slow, we first check using the bare pointer and then use the atomic load if the
// condition is not met.
if (curFifoHead >= size + *(this->tailReplica)) {
OR_POLL_MAYBE_JAILBREAK((curFifoHead >= size + atomicLoad(this->tailReplica, memoryOrderRelaxed)),
(atomicLoad(&(this->triggers[curFifoHead % size].fst), memoryOrderRelaxed) != 0),
maxSpinCount);
if (prevHead >= size + *tailReplica) {
OR_POLL_MAYBE_JAILBREAK((prevHead >= size + atomicLoad(tailReplica, memoryOrderRelaxed)),
(hostLoadRelaxed(&(triggers[prevHead % size].fst)) != 0), maxSpinCount);
}
ProxyTrigger* triggerPtr = &(this->triggers[curFifoHead % size]);
ProxyTrigger* triggerPtr = &(triggers[prevHead % size]);
// Make sure the data is visible to the host before we update the tail.
#if defined(MSCCLPP_DEVICE_CUDA)
#if __CUDA_ARCH__ == 800
// For A100, threadfence_system is more efficient than release
// This is faster than release for A100.
__threadfence_system();
asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
#else
asm volatile("st.global.release.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
#endif
#else // !defined(MSCCLPP_DEVICE_CUDA)
// store snd no later than fst.
// Store snd no later than fst.
atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed);
atomicStore(&(triggerPtr->fst), trigger.fst, memoryOrderRelease);
#endif // !defined(MSCCLPP_DEVICE_CUDA)
return curFifoHead;
return prevHead;
}
/// Wait until there is a place in the FIFO to push a trigger.
///
/// @param curFifoHead The current head of the FIFO.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
MSCCLPP_DEVICE_INLINE void sync(uint64_t curFifoHead, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
/// Wait until a specific trigger is popped from the FIFO.
/// @param fifoHead FIFO head where the trigger was pushed.
/// @param maxSpinCount Max spin count before assert. Never assert if negative.
MSCCLPP_DEVICE_INLINE void sync(uint64_t fifoHead, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
// Same as push but in this case checking the first condition is probably faster since for tail to be pushed we need
// to wait for cudaMemcpy to be done.
OR_POLL_MAYBE_JAILBREAK((curFifoHead >= atomicLoad(this->tailReplica, memoryOrderRelaxed)),
(atomicLoad(&(this->triggers[curFifoHead % size].fst), memoryOrderRelaxed) != 0),
maxSpinCount);
OR_POLL_MAYBE_JAILBREAK((fifoHead >= atomicLoad(tailReplica, memoryOrderRelaxed)),
(hostLoadRelaxed(&(triggers[fifoHead % size].fst)) != 0), maxSpinCount);
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
/// FIFO buffer on host.
ProxyTrigger* triggers;
/// Replica of the FIFO tail.
uint64_t* tailReplica;
/// The FIFO head. Allocated on the device and only accessed by the device.
/// FIFO head on device.
uint64_t* head;
/// The FIFO size.
/// FIFO tail replica on device.
uint64_t* tailReplica;
/// FIFO size.
int size;
};

View File

@@ -123,7 +123,7 @@ namespace detail {
void setReadWriteMemoryAccess(void* base, size_t size);
void* gpuCalloc(size_t bytes);
void* gpuCallocHost(size_t bytes);
void* gpuCallocHost(size_t bytes, unsigned int flags);
#if defined(__HIP_PLATFORM_AMD__)
void* gpuCallocUncached(size_t bytes);
#endif // defined(__HIP_PLATFORM_AMD__)
@@ -206,13 +206,13 @@ auto gpuCallocUnique(size_t nelems = 1) {
}
template <class T>
auto gpuCallocHostShared(size_t nelems = 1) {
return detail::safeAlloc<T, detail::GpuHostDeleter<T>, std::shared_ptr<T>>(detail::gpuCallocHost, nelems);
auto gpuCallocHostShared(size_t nelems = 1, unsigned int flags = cudaHostAllocMapped) {
return detail::safeAlloc<T, detail::GpuHostDeleter<T>, std::shared_ptr<T>>(detail::gpuCallocHost, nelems, flags);
}
template <class T>
auto gpuCallocHostUnique(size_t nelems = 1) {
return detail::safeAlloc<T, detail::GpuHostDeleter<T>, UniqueGpuHostPtr<T>>(detail::gpuCallocHost, nelems);
auto gpuCallocHostUnique(size_t nelems = 1, unsigned int flags = cudaHostAllocMapped) {
return detail::safeAlloc<T, detail::GpuHostDeleter<T>, UniqueGpuHostPtr<T>>(detail::gpuCallocHost, nelems, flags);
}
#if defined(__HIP_PLATFORM_AMD__)

View File

@@ -35,12 +35,6 @@ struct BaseMemoryChannelDeviceHandle {
///
MSCCLPP_DEVICE_INLINE void relaxedSignal() { semaphore_.relaxedSignal(); }
/// Increase the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
/// Read the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
/// Check if the remote semaphore has signaled.
/// @return true if the remote semaphore has signaled.
MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); }

View File

@@ -27,8 +27,8 @@ class BaseProxyService {
class ProxyService : public BaseProxyService {
public:
/// Constructor.
/// @param fifoSize The size of the FIFO used by the proxy service. Default is DEFAULT_FIFO_SIZE.
ProxyService(size_t fifoSize = DEFAULT_FIFO_SIZE);
/// @param fifoSize Size of the FIFO used by the proxy service (default: DEFAULT_FIFO_SIZE).
ProxyService(int fifoSize = DEFAULT_FIFO_SIZE);
/// Build and add a semaphore to the proxy service.
/// @param connection The connection associated with the semaphore.
@@ -72,10 +72,7 @@ class ProxyService : public BaseProxyService {
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;
void bindThread();
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests_;
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
};

View File

@@ -11,53 +11,51 @@
namespace mscclpp {
/// Possible return values of a ProxyHandler.
/// Return values for ProxyHandler.
enum class ProxyHandlerResult {
/// Move to the next trigger in the FIFO.
/// Move to next trigger in FIFO.
Continue,
/// Flush the FIFO and continue to the next trigger.
/// Flush the FIFO and move to next trigger.
FlushFifoTailAndContinue,
/// Stop the proxy and exit.
/// Stop and exit proxy.
Stop,
};
class Proxy;
/// Type of handler function for the proxy.
/// Handler function type for proxy.
using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
/// Host-side proxy for PortChannels.
class Proxy {
public:
/// Constructor of Proxy.
/// @param handler The handler function to be called for each trigger in the FIFO.
/// @param threadInit Optional function to be called in the proxy thread before starting the FIFO consumption.
/// @param fifoSize The size of the FIFO. Default is DEFAULT_FIFO_SIZE.
Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize = DEFAULT_FIFO_SIZE);
/// Constructor.
/// @param handler Handler for each FIFO trigger.
/// @param threadInit Optional function run in proxy thread before FIFO consumption.
/// @param fifoSize FIFO size (default: DEFAULT_FIFO_SIZE).
Proxy(ProxyHandler handler, std::function<void()> threadInit, int fifoSize = DEFAULT_FIFO_SIZE);
/// Constructor of Proxy.
/// @param handler The handler function to be called for each trigger in the FIFO.
/// @param fifoSize The size of the FIFO. Default is DEFAULT_FIFO_SIZE.
Proxy(ProxyHandler handler, size_t fifoSize = DEFAULT_FIFO_SIZE);
/// Constructor.
/// @param handler Handler for each FIFO trigger.
/// @param fifoSize FIFO size (default: DEFAULT_FIFO_SIZE).
Proxy(ProxyHandler handler, int fifoSize = DEFAULT_FIFO_SIZE);
/// Destructor of Proxy.
/// This will stop the proxy if it is running.
/// Destructor. Stops proxy if running.
~Proxy();
/// Start the proxy.
/// Start proxy.
void start();
/// Stop the proxy.
/// Stop proxy.
void stop();
/// This is a concurrent fifo which is multiple threads from the device
/// can produce for and the sole proxy thread consumes it.
/// @return A reference to the FIFO object used by the proxy.
Fifo& fifo();
/// Get reference to FIFO used by proxy.
/// @return Shared pointer to FIFO.
std::shared_ptr<Fifo> fifo();
private:
struct Impl;
std::unique_ptr<Impl> pimpl;
std::unique_ptr<Impl> pimpl_;
};
} // namespace mscclpp

View File

@@ -64,7 +64,7 @@ class BaseSemaphore {
};
/// A semaphore for sending signals from the host to the device.
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::default_delete> {
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, detail::GpuHostDeleter> {
private:
std::shared_ptr<Connection> connection_;

View File

@@ -19,16 +19,33 @@ struct Host2DeviceSemaphoreDeviceHandle {
/// Poll if the host has signaled.
/// @return true if the host has signaled.
MSCCLPP_DEVICE_INLINE bool poll() {
bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
bool signaled = (loadInbound() > loadExpectedInbound());
if (signaled) incExpectedInbound();
return signaled;
}
/// Wait for the host to signal.
MSCCLPP_DEVICE_INLINE void wait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
(*expectedInboundSemaphoreId) += 1;
uint64_t flag = (*expectedInboundSemaphoreId);
POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderAcquire) < flag), maxSpinCount);
auto expected = incExpectedInbound();
POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
}
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
}
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
@@ -43,67 +60,72 @@ struct Host2DeviceSemaphoreDeviceHandle {
/// Device-side handle for MemoryDevice2DeviceSemaphore.
struct MemoryDevice2DeviceSemaphoreDeviceHandle {
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Poll if the remote device has signaled.
/// @return true if the remote device has signaled.
/// Poll if remote device has signaled.
/// @return true if remote device has signaled.
MSCCLPP_DEVICE_INLINE bool poll() {
bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
bool signaled = (loadInbound() > loadExpectedInbound());
if (signaled) incExpectedInbound();
return signaled;
}
/// Wait for the remote device to signal.
/// Wait for remote device to signal.
MSCCLPP_DEVICE_INLINE void wait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
(*expectedInboundSemaphoreId) += 1;
uint64_t flag = (*expectedInboundSemaphoreId);
POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderAcquire) < flag), maxSpinCount);
auto expected = incExpectedInbound();
POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
}
/// Wait for the remote device to signal.
///
/// This function is a relaxed version of Wait() and provides no guarantee on the completion of memory operations.
/// User requires to call proper fencing before using this function.
///
/// Relaxed wait; no memory completion guarantee. Use it only for synchronizing execution, not data.
MSCCLPP_DEVICE_INLINE void relaxedWait([[maybe_unused]] int64_t maxSpinCount = 100000000) {
(*expectedInboundSemaphoreId) += 1;
uint64_t flag = (*expectedInboundSemaphoreId);
POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderRelaxed) < flag), maxSpinCount);
auto expected = incExpectedInbound();
POLL_MAYBE_JAILBREAK((loadInbound() < expected), maxSpinCount);
}
/// Signal the remote device.
///
/// This function guarantees that all the memory operation before this function is completed before the remote
/// semaphore is signaled.
///
/// Signal remote device, ensures prior memory ops complete.
MSCCLPP_DEVICE_INLINE void signal() {
// This fence ensures that preceding writes are visible on the peer GPU before the incremented
// `outboundSemaphoreId` is visible.
semaphoreIncrement();
// use memoryOrderSeqCst instead of memoryOrderRelease since memoryOrderSeqCst
// is more efficient on A100.
#if __CUDA_ARCH__ == 800
atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderSeqCst);
auto outbound = incOutbound();
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
// Using memoryOrderSeqCst is faster for A100.
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderSeqCst);
#else
atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderRelease);
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelease);
#endif
}
/// Signal the remote device.
///
/// This function is a relaxed version of signal() and provides no guarantee on the completion of memory operations.
/// User requires to call proper fencing before using this function.
///
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
// This fence ensures that preceding writes are visible on the peer GPU before the incremented
// `outboundSemaphoreId` is visible.
semaphoreIncrement();
atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderRelaxed);
auto outbound = incOutbound();
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelaxed);
}
/// Increase the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { *outboundSemaphoreId += 1; }
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
}
/// Get the value of the local semaphore.
MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
}
/// Thread-safe read of outbound value.
/// @return The outbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
return atomicLoad<uint64_t, scopeDevice>(outboundSemaphoreId, memoryOrderRelaxed);
}
/// Thread-safe increment of outbound value.
/// @return The incremented outbound value.
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(outboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where the remote device will write its semaphore value and the local device will read it.