New semaphore constructors (#559)

More intuitive interfaces for creating semaphores and channels. Also
allows channel construction using third-party bootstrappers directly
without overriding MSCCL++ Bootstrap.
This commit is contained in:
Changho Hwang
2025-07-11 17:10:46 -07:00
committed by GitHub
parent 20eca28942
commit ae56698d67
42 changed files with 847 additions and 529 deletions

View File

@@ -21,11 +21,11 @@ namespace mscclpp {
#define MSCCLPP_UNIQUE_ID_BYTES 128
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
/// Unique ID for initializing the TcpBootstrap.
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
/// Return a version string.
/// @return A string representing the version of MSCCL++ in the format "major.minor.patch".
/// @return The MSCCL++ version string in "major.minor.patch" format.
std::string version();
/// Base class for bootstraps.
@@ -220,9 +220,6 @@ enum class Transport {
NumTransports, // The number of transports.
};
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
namespace detail {
const size_t TransportFlagsSize = 12;
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
@@ -234,119 +231,99 @@ using TransportFlagsBase = std::bitset<TransportFlagsSize>;
/// Stores transport flags.
class TransportFlags : private detail::TransportFlagsBase {
public:
/// Default constructor for TransportFlags.
/// Constructor.
TransportFlags() = default;
/// Constructor for TransportFlags that takes a Transport enum value.
///
/// Constructor.
/// @param transport The transport to set the flag for.
TransportFlags(Transport transport);
/// Check if a specific transport flag is set.
///
/// @param transport The transport to check the flag for.
/// @return True if the flag is set, false otherwise.
bool has(Transport transport) const;
/// Check if no transport flags are set.
///
/// @return True if no flags are set, false otherwise.
bool none() const;
/// Check if any transport flags are set.
///
/// @return True if any flags are set, false otherwise.
bool any() const;
/// Check if all transport flags are set.
///
/// @return True if all flags are set, false otherwise.
bool all() const;
/// Get the number of transport flags that are set.
///
/// @return The number of flags that are set.
size_t count() const;
/// Bitwise OR assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the OR operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator|=(TransportFlags other);
/// Bitwise OR operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the OR operation with.
/// @return A new TransportFlags object with the result of the OR operation.
TransportFlags operator|(TransportFlags other) const;
/// Bitwise OR operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the OR operation with.
/// @return A new TransportFlags object with the result of the OR operation.
TransportFlags operator|(Transport transport) const;
/// Bitwise AND assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the AND operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator&=(TransportFlags other);
/// Bitwise AND operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the AND operation with.
/// @return A new TransportFlags object with the result of the AND operation.
TransportFlags operator&(TransportFlags other) const;
/// Bitwise AND operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the AND operation with.
/// @return A new TransportFlags object with the result of the AND operation.
TransportFlags operator&(Transport transport) const;
/// Bitwise XOR assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the XOR operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator^=(TransportFlags other);
/// Bitwise XOR operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the XOR operation with.
/// @return A new TransportFlags object with the result of the XOR operation.
TransportFlags operator^(TransportFlags other) const;
/// Bitwise XOR operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the XOR operation with.
/// @return A new TransportFlags object with the result of the XOR operation.
TransportFlags operator^(Transport transport) const;
/// Bitwise NOT operator for TransportFlags.
///
/// @return A new TransportFlags object with the result of the NOT operation.
TransportFlags operator~() const;
/// Equality comparison operator for TransportFlags.
///
/// @param other The other TransportFlags to compare with.
/// @return True if the two TransportFlags objects are equal, false otherwise.
bool operator==(TransportFlags other) const;
/// Inequality comparison operator for TransportFlags.
///
/// @param other The other TransportFlags to compare with.
/// @return True if the two TransportFlags objects are not equal, false otherwise.
bool operator!=(TransportFlags other) const;
/// Convert the TransportFlags object to a bitset representation.
///
/// @return A detail::TransportFlagsBase object representing the TransportFlags object.
detail::TransportFlagsBase toBitset() const;
private:
/// Private constructor for TransportFlags that takes a bitset representation.
///
/// @param bitset The bitset representation of the TransportFlags object.
TransportFlags(detail::TransportFlagsBase bitset);
};
@@ -378,47 +355,64 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) {
return TransportFlags(transport1) ^ transport2;
}
/// Available device types.
enum class DeviceType {
Unknown, // Unknown device type.
CPU, // CPU device type.
GPU, // GPU device type.
};
struct Device {
/// Constructor.
Device() = default;
/// Constructor.
/// @param type Device type.
/// @param id Device ID. Default is -1 (no specific ID).
Device(DeviceType type, int id = -1) : type(type), id(id) {}
/// Device Type.
DeviceType type;
/// Device ID.
int id;
};
class Context;
class Connection;
/// Represents a block of memory that has been registered to a Context.
/// Block of memory that has been registered to a Context.
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
/// to other processes, hence allowing their access to the memory block.
class RegisteredMemory {
public:
/// Default constructor.
/// Constructor.
RegisteredMemory() = default;
/// Destructor.
~RegisteredMemory();
/// Get a pointer to the memory block.
///
/// @return A pointer to the memory block.
void* data() const;
/// Get a pointer to the original memory block.
///
/// @return A pointer to the original memory block.
void* originalDataPtr() const;
/// Get the size of the memory block.
///
/// @return The size of the memory block.
size_t size() const;
/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
TransportFlags transports() const;
/// Serialize the RegisteredMemory object to a vector of characters.
///
/// @return A vector of characters representing the serialized RegisteredMemory object.
std::vector<char> serialize() const;
/// Deserialize a RegisteredMemory object from a vector of characters.
///
/// @param data A vector of characters representing a serialized RegisteredMemory object.
/// @return A deserialized RegisteredMemory object.
static RegisteredMemory deserialize(const std::vector<char>& data);
@@ -430,31 +424,32 @@ class RegisteredMemory {
friend class Context;
friend class Connection;
friend class SemaphoreStub;
};
/// Represents one end of a connection.
/// One end of a connection.
class Endpoint {
public:
/// Default constructor.
/// Constructor.
Endpoint() = default;
/// Get the transport used.
///
/// @return The transport used.
Transport transport();
Transport transport() const;
/// Get the device used.
/// @return The device used.
const Device& device() const;
/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int maxWriteQueueSize();
int maxWriteQueueSize() const;
/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();
std::vector<char> serialize() const;
/// Deserialize an Endpoint object from a vector of characters.
///
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
static Endpoint deserialize(const std::vector<char>& data);
@@ -468,13 +463,13 @@ class Endpoint {
friend class Connection;
};
/// Represents a connection between two processes.
/// Connection between two processes.
class Connection {
public:
/// Constructor.
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
: context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
/// @param localEndpoint The local endpoint of the connection.
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
/// Destructor.
virtual ~Connection() = default;
@@ -498,34 +493,35 @@ class Connection {
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
/// Flush any pending writes to the remote process.
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
virtual void flush(int64_t timeoutUsec = -1) = 0;
/// Get the transport used by the local process.
///
/// @return The transport used by the local process.
virtual Transport transport() const = 0;
/// Get the transport used by the remote process.
///
/// @return The transport used by the remote process.
virtual Transport remoteTransport() const = 0;
/// Get the name of the transport used for this connection
///
/// @return A string formatted as "localTransportName -> remoteTransportName".
std::string getTransportName() const;
/// Get the context associated with this connection.
/// @return A shared pointer to the context associated with this connection.
std::shared_ptr<Context> context() const { return context_; }
/// Get the maximum write queue size
///
/// Get the device used by the local endpoint.
/// @return The device used by the local endpoint.
const Device& localDevice() const;
/// Get the maximum write queue size.
/// @return The maximum number of write requests that can be queued.
int getMaxWriteQueueSize() const;
protected:
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
std::shared_ptr<Context> context_;
Endpoint localEndpoint_;
int maxWriteQueueSize_;
};
@@ -537,6 +533,7 @@ struct EndpointConfig {
static const int DefaultMaxWrPerSend = 64;
Transport transport;
Device device;
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
@@ -546,15 +543,18 @@ struct EndpointConfig {
/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
/// @param device The device to use.
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
int maxWriteQueueSize = -1)
: transport(transport),
device(device),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
@@ -562,7 +562,7 @@ struct EndpointConfig {
maxWriteQueueSize(maxWriteQueueSize) {}
};
/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
/// Context for communication. This provides a low-level interface for forming connections in use-cases
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
/// connect() method.
@@ -619,6 +619,62 @@ class Context : public std::enable_shared_from_this<Context> {
friend class Endpoint;
};
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
class SemaphoreStub {
public:
/// Constructor.
/// @param connection A shared pointer to the connection associated with this semaphore.
SemaphoreStub(std::shared_ptr<Connection> connection);
/// Get the memory associated with this semaphore.
/// @return A reference to the registered memory for this semaphore.
const RegisteredMemory& memory() const;
/// Serialize into a vector of characters.
/// @return A vector of characters representing the serialized SemaphoreStub object.
std::vector<char> serialize() const;
/// Deserialize a SemaphoreStub object from a vector of characters.
/// @param data A vector of characters representing a serialized SemaphoreStub object.
/// @return A deserialized SemaphoreStub object.
static SemaphoreStub deserialize(const std::vector<char>& data);
protected:
struct Impl;
SemaphoreStub(std::shared_ptr<Impl> pimpl);
std::shared_ptr<Impl> pimpl_;
friend class Semaphore;
};
/// Semaphore used by channels for synchronization.
class Semaphore {
public:
/// Constructor.
Semaphore() = default;
/// Constructor.
/// @param localStub SemaphoreStub allocated on the local process.
/// @param remoteStub SemaphoreStub allocated on the remote process.
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
/// Get the connection associated with this semaphore.
/// @return A shared pointer to the connection.
std::shared_ptr<Connection> connection() const;
/// Get the local memory associated with this semaphore.
/// @return A reference to the local registered memory.
const RegisteredMemory& localMemory() const;
/// Get the remote memory associated with this semaphore.
/// @return A reference to the remote registered memory.
const RegisteredMemory& remoteMemory() const;
protected:
struct Impl;
std::shared_ptr<Impl> pimpl_;
};
template <typename T>
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
std::shared_future<T>;
@@ -630,20 +686,26 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// 2. Call registerMemory() to register memory regions that will be used for communication.
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
/// other processes.
/// 4. Call get() on all futures returned by connect() and recvMemory().
/// 5. All done; use connections and registered memories to build channels.
/// 4. Call get() on futures returned by connect(). Use the returned connections to create flags.
/// 5. Call buildSemaphore() to create a Semaphore out of the flags.
/// 6. Call get() on all futures returned by buildSemaphore() and recvMemory().
/// 7. All done; use semaphores and registered memories to build channels.
///
/// CAUTION: The order of method calls matters when the same remote rank and tags are used. That is, the i-th
/// "sending" method call (sendMemory(), connect(), and buildSemaphore()) on the local rank must be matched
/// by the i-th "receiving" method call (recvMemory(), connect(), and buildSemaphore()) on the remote rank.
///
/// Correct Example 1:
/// ```cpp
/// // Rank 0
/// communicator.sendMemory(memory1, 1, tag);
/// communicator.sendMemory(memory2, 1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// // Rank 1
/// auto mem1 = communicator.recvMemory(0, tag);
/// auto mem2 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem2.get(); // This will return memory2.
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
@@ -654,13 +716,13 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// communicator.sendMemory(memory1, 0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem0.get(); // This will return memory0.
/// connection.get(); // This will return the connection.
/// ```
@@ -670,10 +732,10 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag); // undefined behavior
/// communicator.sendMemory(memory1, 0, tag);
/// ```
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
@@ -691,12 +753,10 @@ class Communicator {
~Communicator();
/// Returns the bootstrap held by this communicator.
///
/// @return The bootstrap held by this communicator.
std::shared_ptr<Bootstrap> bootstrap();
/// Returns the context held by this communicator.
///
/// @return The context held by this communicator.
std::shared_ptr<Context> context();
@@ -723,7 +783,7 @@ class Communicator {
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send.
///
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
void sendMemory(RegisteredMemory memory, int remoteRank, int tag = 0);
[[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup(
RegisteredMemory memory, int remoteRank, int tag) {
@@ -752,7 +812,7 @@ class Communicator {
/// @param tag The tag to use for identifying the receive.
/// @return A future of registered memory.
///
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag);
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag = 0);
[[deprecated(
"Use recvMemory() instead. This will be removed in a future release.")]] NonblockingFuture<RegisteredMemory>
@@ -762,7 +822,7 @@ class Communicator {
/// Connect to a remote rank.
///
/// This function will start (but not be waiting for) sending metadata about the local endpoint to the remote rank,
/// This function will start (but not wait for) sending metadata about the local endpoint to the remote rank,
/// and return a future connection without waiting for the remote rank to respond.
/// The connection will be established when the remote rank responds with its own endpoint and the local rank calls
/// the first get() on the future.
@@ -783,19 +843,30 @@ class Communicator {
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
/// back to back.
///
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @param localConfig The configuration for the local endpoint.
/// @return A future of shared pointer to the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>
connect(int remoteRank, int tag, EndpointConfig localConfig);
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
std::shared_ptr<Connection>>
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
return connect(remoteRank, tag, localConfig);
return connect(localConfig, remoteRank, tag);
}
/// Build a semaphore for cross-process synchronization.
/// @param connection The connection associated with this semaphore.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the operation.
/// @return A future of the built semaphore.
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
@@ -842,6 +913,10 @@ using PacketPayload = typename T::Payload;
namespace std {
std::string to_string(const mscclpp::Transport& transport);
std::string to_string(const mscclpp::Device& device);
/// Specialization of the std::hash template for mscclpp::TransportFlags.
template <>
struct hash<mscclpp::TransportFlags>;

View File

@@ -13,24 +13,24 @@
/// Throw mscclpp::CudaError if @p cmd does not return cudaSuccess.
/// @param cmd The command to execute.
#define MSCCLPP_CUDATHROW(cmd) \
do { \
cudaError_t err = cmd; \
if (err != cudaSuccess) { \
throw mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
#define MSCCLPP_CUDATHROW(cmd) \
do { \
cudaError_t err = cmd; \
if (err != cudaSuccess) { \
throw ::mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
} while (false)
/// Throw mscclpp::CuError if @p cmd does not return CUDA_SUCCESS.
/// @param cmd The command to execute.
#define MSCCLPP_CUTHROW(cmd) \
do { \
CUresult err = cmd; \
if (err != CUDA_SUCCESS) { \
throw mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
#define MSCCLPP_CUTHROW(cmd) \
do { \
CUresult err = cmd; \
if (err != CUDA_SUCCESS) { \
throw ::mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
} while (false)
namespace mscclpp {

View File

@@ -18,13 +18,19 @@ struct BaseMemoryChannel {
std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore_;
public:
/// Default constructor.
/// Constructor.
BaseMemoryChannel() = default;
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param semaphore Semaphore used to synchronize the communication.
BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore);
/// Constructor.
/// @param semaphore Semaphore used to synchronize the communication.
BaseMemoryChannel(const Semaphore& semaphore);
/// Constructor.
/// @param other Other BaseMemoryChannel to copy from.
BaseMemoryChannel(const BaseMemoryChannel& other) = default;
BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default;
@@ -33,9 +39,8 @@ struct BaseMemoryChannel {
using DeviceHandle = BaseMemoryChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the BaseMemoryChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};
@@ -47,7 +52,7 @@ struct MemoryChannel : public BaseMemoryChannel {
void* packetBuffer_;
public:
/// Default constructor.
/// Constructor.
MemoryChannel() = default;
/// Constructor.
@@ -59,13 +64,20 @@ struct MemoryChannel : public BaseMemoryChannel {
MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
void* packetBuffer = nullptr);
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param dst Registered memory of the destination.
/// @param src The source memory address.
/// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
/// unpackPacket() and unpackPackets() methods are not available.
MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr);
/// Device-side handle for MemoryChannel.
using DeviceHandle = MemoryChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the MemoryChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};

View File

@@ -35,6 +35,11 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the semaphore.
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(const Semaphore& semaphore);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added
/// @return The ID of the semaphore.
@@ -87,22 +92,36 @@ struct BasePortChannel {
std::shared_ptr<Proxy> proxy_;
public:
/// Constructor.
BasePortChannel() = default;
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param proxy The proxy used for communication.
BasePortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
std::shared_ptr<Proxy> proxy);
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param proxy The proxy used for communication.
BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy);
/// Copy constructor.
/// @param other The other BasePortChannel to copy from.
BasePortChannel(const BasePortChannel& other) = default;
/// Assignment operator.
/// @param other The other BasePortChannel to assign from.
BasePortChannel& operator=(BasePortChannel& other) = default;
/// Device-side handle for BasePortChannel.
using DeviceHandle = BasePortChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the BasePortChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};
@@ -113,7 +132,7 @@ struct PortChannel : public BasePortChannel {
MemoryId src_;
public:
/// Default constructor.
/// Constructor.
PortChannel() = default;
/// Constructor.
@@ -125,19 +144,29 @@ struct PortChannel : public BasePortChannel {
PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy,
MemoryId dst, MemoryId src);
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore.
/// @param proxy The proxy.
/// @param dst The destination memory region.
/// @param src The source memory region.
PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy, MemoryId dst,
MemoryId src);
/// Copy constructor.
/// @param other The other PortChannel to copy from.
PortChannel(const PortChannel& other) = default;
/// Assignment operator.
/// @param other The other PortChannel to assign from.
PortChannel& operator=(PortChannel& other) = default;
/// Device-side handle for PortChannel.
using DeviceHandle = PortChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the PortChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};

View File

@@ -12,63 +12,18 @@
namespace mscclpp {
/// A base class for semaphores.
///
/// A semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a
/// data transfer. The local peer signals the remote peer that it has completed a data transfer by incrementing the
/// outbound semaphore ID. The incremented outbound semaphore ID is copied to the remote peer's inbound semaphore ID so
/// that the remote peer can wait for the local peer to complete a data transfer. Vice versa, the remote peer signals
/// the local peer that it has completed a data transfer by incrementing the remote peer's outbound semaphore ID and
/// copying the incremented value to the local peer's inbound semaphore ID.
///
/// @tparam InboundDeleter The deleter for inbound semaphore IDs. This is either `std::default_delete` for host memory
/// or CudaDeleter for device memory.
/// @tparam OutboundDeleter The deleter for outbound semaphore IDs. This is either `std::default_delete` for host memory
/// or CudaDeleter for device memory.
///
template <template <typename> typename InboundDeleter, template <typename> typename OutboundDeleter>
class BaseSemaphore {
protected:
/// The registered memory for the remote peer's inbound semaphore ID.
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
///
/// The location of localInboundSemaphore_ can be either on the host or on the device.
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphore_;
/// The expected inbound semaphore ID to be incremented by the local peer and compared to the
/// localInboundSemaphore_.
///
/// The location of expectedInboundSemaphore_ can be either on the host or on the device.
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphore_;
/// The outbound semaphore ID that is incremented by the local peer and copied to the remote peer's
/// localInboundSemaphore_.
///
/// The location of outboundSemaphore_ can be either on the host or on the device.
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphore_;
public:
/// Constructs a BaseSemaphore.
///
/// @param localInboundSemaphoreId The inbound semaphore ID
/// @param expectedInboundSemaphoreId The expected inbound semaphore ID
/// @param outboundSemaphoreId The outbound semaphore ID
BaseSemaphore(std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphoreId,
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphoreId,
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphoreId)
: localInboundSemaphore_(std::move(localInboundSemaphoreId)),
expectedInboundSemaphore_(std::move(expectedInboundSemaphoreId)),
outboundSemaphore_(std::move(outboundSemaphoreId)) {}
};
/// A semaphore for sending signals from the host to the device.
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::default_delete> {
class Host2DeviceSemaphore {
private:
std::shared_ptr<Connection> connection_;
Semaphore semaphore_;
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
std::unique_ptr<uint64_t> outboundToken_;
public:
/// Constructor.
/// @param semaphore
Host2DeviceSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
@@ -76,7 +31,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection();
std::shared_ptr<Connection> connection() const;
/// Signal the device.
void signal();
@@ -85,13 +40,22 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle();
DeviceHandle deviceHandle() const;
};
/// A semaphore for sending signals from the local host to a remote host.
class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::default_delete> {
class Host2HostSemaphore {
private:
Semaphore semaphore_;
std::unique_ptr<uint64_t> expectedInboundToken_;
std::unique_ptr<uint64_t> outboundToken_;
public:
/// Constructor
/// Constructor.
/// @param semaphore
Host2HostSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
/// Host2HostSemaphore.
@@ -99,7 +63,7 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection();
std::shared_ptr<Connection> connection() const;
/// Signal the remote host.
void signal();
@@ -111,29 +75,34 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
/// Wait for the remote host to signal.
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
void wait(int64_t maxSpinCount = 10000000);
private:
std::shared_ptr<Connection> connection_;
};
/// A semaphore for sending signals from the local device to a peer device via a GPU thread.
class MemoryDevice2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, detail::GpuDeleter> {
class MemoryDevice2DeviceSemaphore {
private:
Semaphore semaphore_;
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
detail::UniqueGpuPtr<uint64_t> outboundToken_;
public:
/// Constructor.
/// @param semaphore
MemoryDevice2DeviceSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
/// Constructor.
MemoryDevice2DeviceSemaphore() = delete;
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection() const;
/// Device-side handle for MemoryDevice2DeviceSemaphore.
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle() const;
bool isRemoteInboundSemaphoreIdSet_;
};
/// @deprecated Use MemoryDevice2DeviceSemaphore instead.

View File

@@ -33,28 +33,28 @@ struct Host2DeviceSemaphoreDeviceHandle {
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where a host thread (on behalf of the remote device) will write its semaphore value
/// and the local device will read it.
uint64_t* inboundSemaphoreId;
uint64_t* inboundToken;
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
uint64_t* expectedInboundSemaphoreId;
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
uint64_t* expectedInboundToken;
};
/// Device-side handle for MemoryDevice2DeviceSemaphore.
@@ -85,61 +85,61 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
auto outbound = incOutbound();
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
// Using memoryOrderSeqCst is faster for A100.
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderSeqCst);
atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst);
#else
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelease);
atomicStore(remoteInboundToken, outbound, memoryOrderRelease);
#endif
}
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
auto outbound = incOutbound();
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelaxed);
atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed);
}
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
}
/// Thread-safe read of outbound value.
/// @return The outbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
return atomicLoad<uint64_t, scopeDevice>(outboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(outboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of outbound value.
/// @return The incremented outbound value.
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(outboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(outboundToken, 1, memoryOrderRelaxed) + 1;
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where the remote device will write its semaphore value and the local device will read it.
uint64_t* inboundSemaphoreId;
uint64_t* inboundToken;
/// A local memory space where the local device stores the semaphore value to be written to the remote device.
uint64_t* outboundSemaphoreId;
uint64_t* outboundToken;
/// A remote memory space where the local device writes its outboundSemaphoreId on. This is inboundSemaphoreId of the
/// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the
/// remote device.
uint64_t* remoteInboundSemaphoreId;
uint64_t* remoteInboundToken;
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
uint64_t* expectedInboundSemaphoreId;
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
uint64_t* expectedInboundToken;
};
/// @deprecated Use MemoryDevice2DeviceSemaphoreDeviceHandle instead.