mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
New semaphore constructors (#559)
More intuitive interfaces for creating semaphores and channels. Also allows channel construction using third-party bootstrappers directly without overriding MSCCL++ Bootstrap.
This commit is contained in:
@@ -21,11 +21,11 @@ namespace mscclpp {
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
|
||||
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
|
||||
/// Unique ID for initializing the TcpBootstrap.
|
||||
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
|
||||
|
||||
/// Return a version string.
|
||||
/// @return A string representing the version of MSCCL++ in the format "major.minor.patch".
|
||||
/// @return The MSCCL++ version string in "major.minor.patch" format.
|
||||
std::string version();
|
||||
|
||||
/// Base class for bootstraps.
|
||||
@@ -220,9 +220,6 @@ enum class Transport {
|
||||
NumTransports, // The number of transports.
|
||||
};
|
||||
|
||||
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
|
||||
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
|
||||
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 12;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
|
||||
@@ -234,119 +231,99 @@ using TransportFlagsBase = std::bitset<TransportFlagsSize>;
|
||||
/// Stores transport flags.
|
||||
class TransportFlags : private detail::TransportFlagsBase {
|
||||
public:
|
||||
/// Default constructor for TransportFlags.
|
||||
/// Constructor.
|
||||
TransportFlags() = default;
|
||||
|
||||
/// Constructor for TransportFlags that takes a Transport enum value.
|
||||
///
|
||||
/// Constructor.
|
||||
/// @param transport The transport to set the flag for.
|
||||
TransportFlags(Transport transport);
|
||||
|
||||
/// Check if a specific transport flag is set.
|
||||
///
|
||||
/// @param transport The transport to check the flag for.
|
||||
/// @return True if the flag is set, false otherwise.
|
||||
bool has(Transport transport) const;
|
||||
|
||||
/// Check if no transport flags are set.
|
||||
///
|
||||
/// @return True if no flags are set, false otherwise.
|
||||
bool none() const;
|
||||
|
||||
/// Check if any transport flags are set.
|
||||
///
|
||||
/// @return True if any flags are set, false otherwise.
|
||||
bool any() const;
|
||||
|
||||
/// Check if all transport flags are set.
|
||||
///
|
||||
/// @return True if all flags are set, false otherwise.
|
||||
bool all() const;
|
||||
|
||||
/// Get the number of transport flags that are set.
|
||||
///
|
||||
/// @return The number of flags that are set.
|
||||
size_t count() const;
|
||||
|
||||
/// Bitwise OR assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the OR operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator|=(TransportFlags other);
|
||||
|
||||
/// Bitwise OR operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the OR operation with.
|
||||
/// @return A new TransportFlags object with the result of the OR operation.
|
||||
TransportFlags operator|(TransportFlags other) const;
|
||||
|
||||
/// Bitwise OR operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the OR operation with.
|
||||
/// @return A new TransportFlags object with the result of the OR operation.
|
||||
TransportFlags operator|(Transport transport) const;
|
||||
|
||||
/// Bitwise AND assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the AND operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator&=(TransportFlags other);
|
||||
|
||||
/// Bitwise AND operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the AND operation with.
|
||||
/// @return A new TransportFlags object with the result of the AND operation.
|
||||
TransportFlags operator&(TransportFlags other) const;
|
||||
|
||||
/// Bitwise AND operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the AND operation with.
|
||||
/// @return A new TransportFlags object with the result of the AND operation.
|
||||
TransportFlags operator&(Transport transport) const;
|
||||
|
||||
/// Bitwise XOR assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the XOR operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator^=(TransportFlags other);
|
||||
|
||||
/// Bitwise XOR operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the XOR operation with.
|
||||
/// @return A new TransportFlags object with the result of the XOR operation.
|
||||
TransportFlags operator^(TransportFlags other) const;
|
||||
|
||||
/// Bitwise XOR operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the XOR operation with.
|
||||
/// @return A new TransportFlags object with the result of the XOR operation.
|
||||
TransportFlags operator^(Transport transport) const;
|
||||
|
||||
/// Bitwise NOT operator for TransportFlags.
|
||||
///
|
||||
/// @return A new TransportFlags object with the result of the NOT operation.
|
||||
TransportFlags operator~() const;
|
||||
|
||||
/// Equality comparison operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to compare with.
|
||||
/// @return True if the two TransportFlags objects are equal, false otherwise.
|
||||
bool operator==(TransportFlags other) const;
|
||||
|
||||
/// Inequality comparison operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to compare with.
|
||||
/// @return True if the two TransportFlags objects are not equal, false otherwise.
|
||||
bool operator!=(TransportFlags other) const;
|
||||
|
||||
/// Convert the TransportFlags object to a bitset representation.
|
||||
///
|
||||
/// @return A detail::TransportFlagsBase object representing the TransportFlags object.
|
||||
detail::TransportFlagsBase toBitset() const;
|
||||
|
||||
private:
|
||||
/// Private constructor for TransportFlags that takes a bitset representation.
|
||||
///
|
||||
/// @param bitset The bitset representation of the TransportFlags object.
|
||||
TransportFlags(detail::TransportFlagsBase bitset);
|
||||
};
|
||||
@@ -378,47 +355,64 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) ^ transport2;
|
||||
}
|
||||
|
||||
/// Available device types.
|
||||
enum class DeviceType {
|
||||
Unknown, // Unknown device type.
|
||||
CPU, // CPU device type.
|
||||
GPU, // GPU device type.
|
||||
};
|
||||
|
||||
struct Device {
|
||||
/// Constructor.
|
||||
Device() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param type Device type.
|
||||
/// @param id Device ID. Default is -1 (no specific ID).
|
||||
Device(DeviceType type, int id = -1) : type(type), id(id) {}
|
||||
|
||||
/// Device Type.
|
||||
DeviceType type;
|
||||
|
||||
/// Device ID.
|
||||
int id;
|
||||
};
|
||||
|
||||
class Context;
|
||||
class Connection;
|
||||
|
||||
/// Represents a block of memory that has been registered to a Context.
|
||||
/// Block of memory that has been registered to a Context.
|
||||
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
|
||||
/// to other processes, hence allowing their access to the memory block.
|
||||
class RegisteredMemory {
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
RegisteredMemory() = default;
|
||||
|
||||
/// Destructor.
|
||||
~RegisteredMemory();
|
||||
|
||||
/// Get a pointer to the memory block.
|
||||
///
|
||||
/// @return A pointer to the memory block.
|
||||
void* data() const;
|
||||
|
||||
/// Get a pointer to the original memory block.
|
||||
///
|
||||
/// @return A pointer to the original memory block.
|
||||
void* originalDataPtr() const;
|
||||
|
||||
/// Get the size of the memory block.
|
||||
///
|
||||
/// @return The size of the memory block.
|
||||
size_t size() const;
|
||||
|
||||
/// Get the transport flags associated with the memory block.
|
||||
///
|
||||
/// @return The transport flags associated with the memory block.
|
||||
TransportFlags transports() const;
|
||||
|
||||
/// Serialize the RegisteredMemory object to a vector of characters.
|
||||
///
|
||||
/// @return A vector of characters representing the serialized RegisteredMemory object.
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize a RegisteredMemory object from a vector of characters.
|
||||
///
|
||||
/// @param data A vector of characters representing a serialized RegisteredMemory object.
|
||||
/// @return A deserialized RegisteredMemory object.
|
||||
static RegisteredMemory deserialize(const std::vector<char>& data);
|
||||
@@ -430,31 +424,32 @@ class RegisteredMemory {
|
||||
|
||||
friend class Context;
|
||||
friend class Connection;
|
||||
friend class SemaphoreStub;
|
||||
};
|
||||
|
||||
/// Represents one end of a connection.
|
||||
/// One end of a connection.
|
||||
class Endpoint {
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
Endpoint() = default;
|
||||
|
||||
/// Get the transport used.
|
||||
///
|
||||
/// @return The transport used.
|
||||
Transport transport();
|
||||
Transport transport() const;
|
||||
|
||||
/// Get the device used.
|
||||
/// @return The device used.
|
||||
const Device& device() const;
|
||||
|
||||
/// Get the maximum write queue size.
|
||||
///
|
||||
/// @return The maximum number of write requests that can be queued.
|
||||
int maxWriteQueueSize();
|
||||
int maxWriteQueueSize() const;
|
||||
|
||||
/// Serialize the Endpoint object to a vector of characters.
|
||||
///
|
||||
/// @return A vector of characters representing the serialized Endpoint object.
|
||||
std::vector<char> serialize();
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize an Endpoint object from a vector of characters.
|
||||
///
|
||||
/// @param data A vector of characters representing a serialized Endpoint object.
|
||||
/// @return A deserialized Endpoint object.
|
||||
static Endpoint deserialize(const std::vector<char>& data);
|
||||
@@ -468,13 +463,13 @@ class Endpoint {
|
||||
friend class Connection;
|
||||
};
|
||||
|
||||
/// Represents a connection between two processes.
|
||||
/// Connection between two processes.
|
||||
class Connection {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
|
||||
Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
|
||||
: context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
|
||||
/// @param localEndpoint The local endpoint of the connection.
|
||||
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
|
||||
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
|
||||
|
||||
/// Destructor.
|
||||
virtual ~Connection() = default;
|
||||
@@ -498,34 +493,35 @@ class Connection {
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
|
||||
/// Flush any pending writes to the remote process.
|
||||
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
|
||||
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
|
||||
virtual void flush(int64_t timeoutUsec = -1) = 0;
|
||||
|
||||
/// Get the transport used by the local process.
|
||||
///
|
||||
/// @return The transport used by the local process.
|
||||
virtual Transport transport() const = 0;
|
||||
|
||||
/// Get the transport used by the remote process.
|
||||
///
|
||||
/// @return The transport used by the remote process.
|
||||
virtual Transport remoteTransport() const = 0;
|
||||
|
||||
/// Get the name of the transport used for this connection
|
||||
///
|
||||
/// @return A string formatted as "localTransportName -> remoteTransportName".
|
||||
std::string getTransportName() const;
|
||||
/// Get the context associated with this connection.
|
||||
/// @return A shared pointer to the context associated with this connection.
|
||||
std::shared_ptr<Context> context() const { return context_; }
|
||||
|
||||
/// Get the maximum write queue size
|
||||
///
|
||||
/// Get the device used by the local endpoint.
|
||||
/// @return The device used by the local endpoint.
|
||||
const Device& localDevice() const;
|
||||
|
||||
/// Get the maximum write queue size.
|
||||
/// @return The maximum number of write requests that can be queued.
|
||||
int getMaxWriteQueueSize() const;
|
||||
|
||||
protected:
|
||||
// Internal methods for getting implementation pointers.
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
|
||||
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
|
||||
|
||||
std::shared_ptr<Context> context_;
|
||||
Endpoint localEndpoint_;
|
||||
int maxWriteQueueSize_;
|
||||
};
|
||||
|
||||
@@ -537,6 +533,7 @@ struct EndpointConfig {
|
||||
static const int DefaultMaxWrPerSend = 64;
|
||||
|
||||
Transport transport;
|
||||
Device device;
|
||||
int ibMaxCqSize;
|
||||
int ibMaxCqPollNum;
|
||||
int ibMaxSendWr;
|
||||
@@ -546,15 +543,18 @@ struct EndpointConfig {
|
||||
/// Constructor that takes a transport and sets the other fields to their default values.
|
||||
///
|
||||
/// @param transport The transport to use.
|
||||
/// @param device The device to use.
|
||||
/// @param ibMaxCqSize The maximum completion queue size.
|
||||
/// @param ibMaxCqPollNum The maximum completion queue poll number.
|
||||
/// @param ibMaxSendWr The maximum send work requests.
|
||||
/// @param ibMaxWrPerSend The maximum work requests per send.
|
||||
/// @param maxWriteQueueSize The maximum write queue size.
|
||||
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
|
||||
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
|
||||
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
|
||||
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
|
||||
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
|
||||
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
|
||||
int maxWriteQueueSize = -1)
|
||||
: transport(transport),
|
||||
device(device),
|
||||
ibMaxCqSize(ibMaxCqSize),
|
||||
ibMaxCqPollNum(ibMaxCqPollNum),
|
||||
ibMaxSendWr(ibMaxSendWr),
|
||||
@@ -562,7 +562,7 @@ struct EndpointConfig {
|
||||
maxWriteQueueSize(maxWriteQueueSize) {}
|
||||
};
|
||||
|
||||
/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
|
||||
/// Context for communication. This provides a low-level interface for forming connections in use-cases
|
||||
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
|
||||
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
|
||||
/// connect() method.
|
||||
@@ -619,6 +619,62 @@ class Context : public std::enable_shared_from_this<Context> {
|
||||
friend class Endpoint;
|
||||
};
|
||||
|
||||
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
|
||||
class SemaphoreStub {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param connection A shared pointer to the connection associated with this semaphore.
|
||||
SemaphoreStub(std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Get the memory associated with this semaphore.
|
||||
/// @return A reference to the registered memory for this semaphore.
|
||||
const RegisteredMemory& memory() const;
|
||||
|
||||
/// Serialize into a vector of characters.
|
||||
/// @return A vector of characters representing the serialized SemaphoreStub object.
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize a SemaphoreStub object from a vector of characters.
|
||||
/// @param data A vector of characters representing a serialized SemaphoreStub object.
|
||||
/// @return A deserialized SemaphoreStub object.
|
||||
static SemaphoreStub deserialize(const std::vector<char>& data);
|
||||
|
||||
protected:
|
||||
struct Impl;
|
||||
SemaphoreStub(std::shared_ptr<Impl> pimpl);
|
||||
std::shared_ptr<Impl> pimpl_;
|
||||
|
||||
friend class Semaphore;
|
||||
};
|
||||
|
||||
/// Semaphore used by channels for synchronization.
|
||||
class Semaphore {
|
||||
public:
|
||||
/// Constructor.
|
||||
Semaphore() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param localStub SemaphoreStub allocated on the local process.
|
||||
/// @param remoteStub SemaphoreStub allocated on the remote process.
|
||||
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
|
||||
|
||||
/// Get the connection associated with this semaphore.
|
||||
/// @return A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Get the local memory associated with this semaphore.
|
||||
/// @return A reference to the local registered memory.
|
||||
const RegisteredMemory& localMemory() const;
|
||||
|
||||
/// Get the remote memory associated with this semaphore.
|
||||
/// @return A reference to the remote registered memory.
|
||||
const RegisteredMemory& remoteMemory() const;
|
||||
|
||||
protected:
|
||||
struct Impl;
|
||||
std::shared_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
|
||||
std::shared_future<T>;
|
||||
@@ -630,20 +686,26 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// 2. Call registerMemory() to register memory regions that will be used for communication.
|
||||
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
|
||||
/// other processes.
|
||||
/// 4. Call get() on all futures returned by connect() and recvMemory().
|
||||
/// 5. All done; use connections and registered memories to build channels.
|
||||
/// 4. Call get() on futures returned by connect(). Use the returned connections to create flags.
|
||||
/// 5. Call buildSemaphore() to create a Semaphore out of the flags.
|
||||
/// 6. Call get() on all futures returned by buildSemaphore() and recvMemory().
|
||||
/// 7. All done; use semaphores and registered memories to build channels.
|
||||
///
|
||||
/// CAUTION: The order of method calls matters when the same remote rank and tags are used. That is, the i-th
|
||||
/// "sending" method call (sendMemory(), connect(), and buildSemaphore()) on the local rank must be matched
|
||||
/// by the i-th "receiving" method call (recvMemory(), connect(), and buildSemaphore()) on the remote rank.
|
||||
///
|
||||
/// Correct Example 1:
|
||||
/// ```cpp
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory1, 1, tag);
|
||||
/// communicator.sendMemory(memory2, 1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// // Rank 1
|
||||
/// auto mem1 = communicator.recvMemory(0, tag);
|
||||
/// auto mem2 = communicator.recvMemory(0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
|
||||
/// mem2.get(); // This will return memory2.
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// mem1.get(); // This will return memory1.
|
||||
@@ -654,13 +716,13 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory0, 1, tag);
|
||||
/// auto mem1 = communicator.recvMemory(1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// mem1.get(); // This will return memory1.
|
||||
/// // Rank 1
|
||||
/// auto mem0 = communicator.recvMemory(0, tag);
|
||||
/// communicator.sendMemory(memory1, 0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
|
||||
/// mem0.get(); // This will return memory0.
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// ```
|
||||
@@ -670,10 +732,10 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory0, 1, tag);
|
||||
/// auto mem1 = communicator.recvMemory(1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// // Rank 1
|
||||
/// auto mem0 = communicator.recvMemory(0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag); // undefined behavior
|
||||
/// communicator.sendMemory(memory1, 0, tag);
|
||||
/// ```
|
||||
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
|
||||
@@ -691,12 +753,10 @@ class Communicator {
|
||||
~Communicator();
|
||||
|
||||
/// Returns the bootstrap held by this communicator.
|
||||
///
|
||||
/// @return The bootstrap held by this communicator.
|
||||
std::shared_ptr<Bootstrap> bootstrap();
|
||||
|
||||
/// Returns the context held by this communicator.
|
||||
///
|
||||
/// @return The context held by this communicator.
|
||||
std::shared_ptr<Context> context();
|
||||
|
||||
@@ -723,7 +783,7 @@ class Communicator {
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send.
|
||||
///
|
||||
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
|
||||
void sendMemory(RegisteredMemory memory, int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup(
|
||||
RegisteredMemory memory, int remoteRank, int tag) {
|
||||
@@ -752,7 +812,7 @@ class Communicator {
|
||||
/// @param tag The tag to use for identifying the receive.
|
||||
/// @return A future of registered memory.
|
||||
///
|
||||
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag);
|
||||
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated(
|
||||
"Use recvMemory() instead. This will be removed in a future release.")]] NonblockingFuture<RegisteredMemory>
|
||||
@@ -762,7 +822,7 @@ class Communicator {
|
||||
|
||||
/// Connect to a remote rank.
|
||||
///
|
||||
/// This function will start (but not be waiting for) sending metadata about the local endpoint to the remote rank,
|
||||
/// This function will start (but not wait for) sending metadata about the local endpoint to the remote rank,
|
||||
/// and return a future connection without waiting for the remote rank to respond.
|
||||
/// The connection will be established when the remote rank responds with its own endpoint and the local rank calls
|
||||
/// the first get() on the future.
|
||||
@@ -783,19 +843,30 @@ class Communicator {
|
||||
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
|
||||
/// back to back.
|
||||
///
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
///
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
|
||||
shared_future<std::shared_ptr<Connection>>
|
||||
connect(int remoteRank, int tag, EndpointConfig localConfig);
|
||||
|
||||
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
|
||||
std::shared_ptr<Connection>>
|
||||
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return connect(remoteRank, tag, localConfig);
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
}
|
||||
|
||||
/// Build a semaphore for cross-process synchronization.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the operation.
|
||||
/// @return A future of the built semaphore.
|
||||
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
|
||||
|
||||
/// Get the remote rank a connection is connected to.
|
||||
///
|
||||
/// @param connection The connection to get the remote rank for.
|
||||
@@ -842,6 +913,10 @@ using PacketPayload = typename T::Payload;
|
||||
|
||||
namespace std {
|
||||
|
||||
std::string to_string(const mscclpp::Transport& transport);
|
||||
|
||||
std::string to_string(const mscclpp::Device& device);
|
||||
|
||||
/// Specialization of the std::hash template for mscclpp::TransportFlags.
|
||||
template <>
|
||||
struct hash<mscclpp::TransportFlags>;
|
||||
|
||||
@@ -13,24 +13,24 @@
|
||||
|
||||
/// Throw mscclpp::CudaError if @p cmd does not return cudaSuccess.
|
||||
/// @param cmd The command to execute.
|
||||
#define MSCCLPP_CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
#define MSCCLPP_CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw ::mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
/// Throw mscclpp::CuError if @p cmd does not return CUDA_SUCCESS.
|
||||
/// @param cmd The command to execute.
|
||||
#define MSCCLPP_CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
throw mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
#define MSCCLPP_CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
throw ::mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -18,13 +18,19 @@ struct BaseMemoryChannel {
|
||||
std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
BaseMemoryChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param semaphore Semaphore used to synchronize the communication.
|
||||
BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore Semaphore used to synchronize the communication.
|
||||
BaseMemoryChannel(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param other Other BaseMemoryChannel to copy from.
|
||||
BaseMemoryChannel(const BaseMemoryChannel& other) = default;
|
||||
|
||||
BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default;
|
||||
@@ -33,9 +39,8 @@ struct BaseMemoryChannel {
|
||||
using DeviceHandle = BaseMemoryChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the BaseMemoryChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
@@ -47,7 +52,7 @@ struct MemoryChannel : public BaseMemoryChannel {
|
||||
void* packetBuffer_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
MemoryChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
@@ -59,13 +64,20 @@ struct MemoryChannel : public BaseMemoryChannel {
|
||||
MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
|
||||
void* packetBuffer = nullptr);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param dst Registered memory of the destination.
|
||||
/// @param src The source memory address.
|
||||
/// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
|
||||
/// unpackPacket() and unpackPackets() methods are not available.
|
||||
MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr);
|
||||
|
||||
/// Device-side handle for MemoryChannel.
|
||||
using DeviceHandle = MemoryChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the MemoryChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -35,6 +35,11 @@ class ProxyService : public BaseProxyService {
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId addSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
/// @return The ID of the semaphore.
|
||||
@@ -87,22 +92,36 @@ struct BasePortChannel {
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
BasePortChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param proxy The proxy used for communication.
|
||||
BasePortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
|
||||
std::shared_ptr<Proxy> proxy);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param proxy The proxy used for communication.
|
||||
BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy);
|
||||
|
||||
/// Copy constructor.
|
||||
/// @param other The other BasePortChannel to copy from.
|
||||
BasePortChannel(const BasePortChannel& other) = default;
|
||||
|
||||
/// Assignment operator.
|
||||
/// @param other The other BasePortChannel to assign from.
|
||||
BasePortChannel& operator=(BasePortChannel& other) = default;
|
||||
|
||||
/// Device-side handle for BasePortChannel.
|
||||
using DeviceHandle = BasePortChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the BasePortChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
@@ -113,7 +132,7 @@ struct PortChannel : public BasePortChannel {
|
||||
MemoryId src_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
PortChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
@@ -125,19 +144,29 @@ struct PortChannel : public BasePortChannel {
|
||||
PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy,
|
||||
MemoryId dst, MemoryId src);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore.
|
||||
/// @param proxy The proxy.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy, MemoryId dst,
|
||||
MemoryId src);
|
||||
|
||||
/// Copy constructor.
|
||||
/// @param other The other PortChannel to copy from.
|
||||
PortChannel(const PortChannel& other) = default;
|
||||
|
||||
/// Assignment operator.
|
||||
/// @param other The other PortChannel to assign from.
|
||||
PortChannel& operator=(PortChannel& other) = default;
|
||||
|
||||
/// Device-side handle for PortChannel.
|
||||
using DeviceHandle = PortChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the PortChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -12,63 +12,18 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// A base class for semaphores.
|
||||
///
|
||||
/// A semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a
|
||||
/// data transfer. The local peer signals the remote peer that it has completed a data transfer by incrementing the
|
||||
/// outbound semaphore ID. The incremented outbound semaphore ID is copied to the remote peer's inbound semaphore ID so
|
||||
/// that the remote peer can wait for the local peer to complete a data transfer. Vice versa, the remote peer signals
|
||||
/// the local peer that it has completed a data transfer by incrementing the remote peer's outbound semaphore ID and
|
||||
/// copying the incremented value to the local peer's inbound semaphore ID.
|
||||
///
|
||||
/// @tparam InboundDeleter The deleter for inbound semaphore IDs. This is either `std::default_delete` for host memory
|
||||
/// or CudaDeleter for device memory.
|
||||
/// @tparam OutboundDeleter The deleter for outbound semaphore IDs. This is either `std::default_delete` for host memory
|
||||
/// or CudaDeleter for device memory.
|
||||
///
|
||||
template <template <typename> typename InboundDeleter, template <typename> typename OutboundDeleter>
|
||||
class BaseSemaphore {
|
||||
protected:
|
||||
/// The registered memory for the remote peer's inbound semaphore ID.
|
||||
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
|
||||
|
||||
/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
|
||||
///
|
||||
/// The location of localInboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphore_;
|
||||
|
||||
/// The expected inbound semaphore ID to be incremented by the local peer and compared to the
|
||||
/// localInboundSemaphore_.
|
||||
///
|
||||
/// The location of expectedInboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphore_;
|
||||
|
||||
/// The outbound semaphore ID that is incremented by the local peer and copied to the remote peer's
|
||||
/// localInboundSemaphore_.
|
||||
///
|
||||
/// The location of outboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphore_;
|
||||
|
||||
public:
|
||||
/// Constructs a BaseSemaphore.
|
||||
///
|
||||
/// @param localInboundSemaphoreId The inbound semaphore ID
|
||||
/// @param expectedInboundSemaphoreId The expected inbound semaphore ID
|
||||
/// @param outboundSemaphoreId The outbound semaphore ID
|
||||
BaseSemaphore(std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphoreId,
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphoreId,
|
||||
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphoreId)
|
||||
: localInboundSemaphore_(std::move(localInboundSemaphoreId)),
|
||||
expectedInboundSemaphore_(std::move(expectedInboundSemaphoreId)),
|
||||
outboundSemaphore_(std::move(outboundSemaphoreId)) {}
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the host to the device.
|
||||
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::default_delete> {
|
||||
class Host2DeviceSemaphore {
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
Semaphore semaphore_;
|
||||
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
|
||||
std::unique_ptr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param semaphore
|
||||
Host2DeviceSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
@@ -76,7 +31,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection();
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Signal the device.
|
||||
void signal();
|
||||
@@ -85,13 +40,22 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
|
||||
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle();
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the local host to a remote host.
|
||||
class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::default_delete> {
|
||||
class Host2HostSemaphore {
|
||||
private:
|
||||
Semaphore semaphore_;
|
||||
std::unique_ptr<uint64_t> expectedInboundToken_;
|
||||
std::unique_ptr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// Constructor.
|
||||
/// @param semaphore
|
||||
Host2HostSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
|
||||
/// Host2HostSemaphore.
|
||||
@@ -99,7 +63,7 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection();
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Signal the remote host.
|
||||
void signal();
|
||||
@@ -111,29 +75,34 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
|
||||
/// Wait for the remote host to signal.
|
||||
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
|
||||
void wait(int64_t maxSpinCount = 10000000);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the local device to a peer device via a GPU thread.
|
||||
class MemoryDevice2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, detail::GpuDeleter> {
|
||||
class MemoryDevice2DeviceSemaphore {
|
||||
private:
|
||||
Semaphore semaphore_;
|
||||
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
|
||||
detail::UniqueGpuPtr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param semaphore
|
||||
MemoryDevice2DeviceSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Constructor.
|
||||
MemoryDevice2DeviceSemaphore() = delete;
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Device-side handle for MemoryDevice2DeviceSemaphore.
|
||||
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
|
||||
bool isRemoteInboundSemaphoreIdSet_;
|
||||
};
|
||||
|
||||
/// @deprecated Use MemoryDevice2DeviceSemaphore instead.
|
||||
|
||||
@@ -33,28 +33,28 @@ struct Host2DeviceSemaphoreDeviceHandle {
|
||||
/// Thread-safe read of expected inbound value.
|
||||
/// @return The expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of expected inbound value.
|
||||
/// @return The incremented expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
|
||||
/// Thread-safe read of inbound value.
|
||||
/// @return The inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
/// A local memory space where a host thread (on behalf of the remote device) will write its semaphore value
|
||||
/// and the local device will read it.
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* inboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
|
||||
uint64_t* expectedInboundToken;
|
||||
};
|
||||
|
||||
/// Device-side handle for MemoryDevice2DeviceSemaphore.
|
||||
@@ -85,61 +85,61 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
|
||||
auto outbound = incOutbound();
|
||||
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
|
||||
// Using memoryOrderSeqCst is faster for A100.
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderSeqCst);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst);
|
||||
#else
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelease);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderRelease);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
|
||||
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
|
||||
auto outbound = incOutbound();
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelaxed);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe read of expected inbound value.
|
||||
/// @return The expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of expected inbound value.
|
||||
/// @return The incremented expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
|
||||
/// Thread-safe read of inbound value.
|
||||
/// @return The inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
|
||||
}
|
||||
|
||||
/// Thread-safe read of outbound value.
|
||||
/// @return The outbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(outboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(outboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of outbound value.
|
||||
/// @return The incremented outbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(outboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(outboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
/// A local memory space where the remote device will write its semaphore value and the local device will read it.
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* inboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the semaphore value to be written to the remote device.
|
||||
uint64_t* outboundSemaphoreId;
|
||||
uint64_t* outboundToken;
|
||||
|
||||
/// A remote memory space where the local device writes its outboundSemaphoreId on. This is inboundSemaphoreId of the
|
||||
/// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the
|
||||
/// remote device.
|
||||
uint64_t* remoteInboundSemaphoreId;
|
||||
uint64_t* remoteInboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
|
||||
uint64_t* expectedInboundToken;
|
||||
};
|
||||
|
||||
/// @deprecated Use MemoryDevice2DeviceSemaphoreDeviceHandle instead.
|
||||
|
||||
Reference in New Issue
Block a user