Files
mscclpp/include/mscclpp/core.hpp
Binyang Li 5380a4ac6e Add MSCCLPP_IB_GID_INDEX env (#780)
Use MSCCLPP_IB_GID_INDEX to control ib gid index

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-13 09:59:42 -07:00

989 lines
41 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_CORE_HPP_
#define MSCCLPP_CORE_HPP_
#include <array>
#include <bitset>
#include <future>
#include <memory>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/version.hpp>
#include <string>
#include <vector>
namespace mscclpp {
/// Size of a UniqueId, in bytes.
constexpr unsigned int UniqueIdBytes = 128U;

/// Fixed-size byte array used as the unique ID that initializes a TcpBootstrap.
using UniqueId = std::array<uint8_t, UniqueIdBytes>;

/// Get the library version.
/// @return The MSCCL++ version string in "major.minor.patch" format.
std::string version();
/// Base class for bootstraps.
///
/// Declares the rank queries, point-to-point send()/recv(), allGather() and barrier() that every
/// bootstrap implementation must provide, plus convenience wrappers that work on `std::vector<char>`.
class Bootstrap {
 public:
  /// Constructor.
  Bootstrap() = default;

  /// Destructor.
  virtual ~Bootstrap() = default;

  /// Return the rank of the process.
  /// @return The rank of the process.
  virtual int getRank() const = 0;

  /// Return the total number of ranks.
  /// @return The total number of ranks.
  virtual int getNranks() const = 0;

  /// Return the total number of ranks per node.
  /// @return The total number of ranks per node.
  virtual int getNranksPerNode() const = 0;

  /// Send arbitrary data to another process.
  ///
  /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
  /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by
  /// the order of calls, corresponding to the order of recv() calls on the receiving side. In cases where
  /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated
  /// by using different @p tag values to prevent unexpected behavior.
  ///
  /// @param data The data to send.
  /// @param size The size of the data to send.
  /// @param peer The rank of the process to send the data to.
  /// @param tag The tag to send the data with.
  virtual void send(void* data, int size, int peer, int tag) = 0;

  /// Receive data sent from another process by send().
  ///
  /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
  /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by
  /// the order of calls, corresponding to the order of recv() calls on the receiving side. In cases where
  /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated
  /// by using different @p tag values to prevent unexpected behavior.
  ///
  /// @param data The buffer to write the received data to.
  /// @param size The size of the data to receive.
  /// @param peer The rank of the process to receive the data from.
  /// @param tag The tag to receive the data with.
  virtual void recv(void* data, int size, int peer, int tag) = 0;

  /// Gather data from all processes.
  ///
  /// When called by rank `r`, this sends data from `allData[r * size]` to `allData[(r + 1) * size - 1]` to all other
  /// ranks. The data sent by rank `r` is received into `allData[r * size]` of other ranks.
  ///
  /// @param allData The buffer to write the received data to.
  /// @param size The size of the data each rank sends.
  virtual void allGather(void* allData, int size) = 0;

  /// Synchronize all processes.
  virtual void barrier() = 0;

  /// A partial barrier that synchronizes a group of ranks.
  /// @param ranks The ranks to synchronize.
  void groupBarrier(const std::vector<int>& ranks);

  /// Wrapper of send() that sends a vector of characters.
  /// @param data The data to send.
  /// @param peer The rank of the process to send the data to.
  /// @param tag The tag to send the data with.
  void send(const std::vector<char>& data, int peer, int tag);

  /// Wrapper of recv() that receives a vector of characters.
  /// @param data The buffer to write the received data to.
  /// @param peer The rank of the process to receive the data from.
  /// @param tag The tag to receive the data with.
  ///
  /// @note The data vector will be resized to the size of the received data.
  void recv(std::vector<char>& data, int peer, int tag);
};
/// Bootstrap implementation built natively on TCP sockets.
class TcpBootstrap : public Bootstrap {
 public:
  /// Generate a random unique ID.
  /// @return The newly created unique ID.
  static UniqueId createUniqueId();

  /// Constructor.
  /// @param rank Rank of this process.
  /// @param nRanks Total number of ranks.
  TcpBootstrap(int rank, int nRanks);

  /// Destructor.
  ~TcpBootstrap();

  /// Get the unique ID held by this TcpBootstrap.
  /// @return The stored unique ID.
  UniqueId getUniqueId() const;

  /// Initialize the TcpBootstrap from a unique ID. The ID may come from any source: it can be produced by
  /// createUniqueId() or be an arbitrary bit array supplied by the user.
  /// @param uniqueId Unique ID to initialize with.
  /// @param timeoutSec Connection timeout in seconds.
  void initialize(UniqueId uniqueId, int64_t timeoutSec = 30);

  /// Initialize the TcpBootstrap from a string of the form "ip:port" or "interface:ip:port".
  /// @param ifIpPortTrio String of the form "ip:port" or "interface:ip:port".
  /// @param timeoutSec Connection timeout in seconds.
  void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);

  /// Get the rank of this process.
  int getRank() const override;

  /// Get the total number of ranks.
  int getNranks() const override;

  /// Get the total number of ranks per node.
  int getNranksPerNode() const override;

  /// Send arbitrary data to another process.
  ///
  /// A `send(senderBuff, size, receiverRank, tag)` is matched by a `recv(receiverBuff, size, senderRank, tag)`.
  /// Repeated send() calls that share the same @p peer and @p tag are ordered by call order, which corresponds
  /// to the order of the recv() calls on the receiving side. When the relative execution order of several
  /// send()s or recv()s between two ranks is not known, give them distinct @p tag values to prevent unexpected
  /// behavior.
  ///
  /// @param data Data to send.
  /// @param size Size of the data to send.
  /// @param peer Rank of the destination process.
  /// @param tag Tag to send the data with.
  void send(void* data, int size, int peer, int tag) override;

  /// Receive data that another process sent with send().
  ///
  /// A `send(senderBuff, size, receiverRank, tag)` is matched by a `recv(receiverBuff, size, senderRank, tag)`.
  /// Repeated send() calls that share the same @p peer and @p tag are ordered by call order, which corresponds
  /// to the order of the recv() calls on the receiving side. When the relative execution order of several
  /// send()s or recv()s between two ranks is not known, give them distinct @p tag values to prevent unexpected
  /// behavior.
  ///
  /// @param data Buffer that receives the data.
  /// @param size Size of the data to receive.
  /// @param peer Rank of the source process.
  /// @param tag Tag to receive the data with.
  void recv(void* data, int size, int peer, int tag) override;

  /// Gather data from all processes.
  ///
  /// Rank `r` sends the range `allData[r * size]` .. `allData[(r + 1) * size - 1]` to every other rank, and
  /// the data sent by rank `r` lands at `allData[r * size]` on the other ranks.
  ///
  /// @param allData Buffer that receives the gathered data.
  /// @param size Size of the data contributed by each rank.
  void allGather(void* allData, int size) override;

  /// Broadcast data from the root process to all processes using a ring-based algorithm.
  ///
  /// On the root rank, the `size` bytes starting at `data` are sent to all other ranks. Non-root ranks
  /// receive those bytes into their own `data` buffer, overwriting whatever it held before. The data
  /// travels sequentially through a logical ring of processes until every rank has received it.
  ///
  /// @param data Send buffer on the root, receive buffer on non-root ranks.
  /// @param size Number of bytes to broadcast.
  /// @param root Rank that initiates the broadcast.
  void broadcast(void* data, int size, int root);

  /// Synchronize all processes.
  void barrier() override;

 private:
  class Impl;
  std::unique_ptr<Impl> pimpl_;
};
/// The transport types that are available.
enum class Transport {
  /// Unknown transport type.
  Unknown,
  /// CUDA IPC transport type.
  CudaIpc,
  /// InfiniBand device 0 transport type.
  IB0,
  /// InfiniBand device 1 transport type.
  IB1,
  /// InfiniBand device 2 transport type.
  IB2,
  /// InfiniBand device 3 transport type.
  IB3,
  /// InfiniBand device 4 transport type.
  IB4,
  /// InfiniBand device 5 transport type.
  IB5,
  /// InfiniBand device 6 transport type.
  IB6,
  /// InfiniBand device 7 transport type.
  IB7,
  /// Ethernet transport type.
  Ethernet,
  /// The number of transports (not a transport itself).
  NumTransports,
};
namespace detail {

/// Number of bits in the transport-flag bitset; one bit per Transport enumerator before NumTransports.
constexpr size_t TransportFlagsSize = 11;

// Compile-time guard: fails the build if the Transport enum changes without this constant being updated.
static_assert(static_cast<size_t>(Transport::NumTransports) == TransportFlagsSize,
              "TransportFlagsSize must match the number of transports");

/// Bitset type that stores transport flags.
using TransportFlagsBase = std::bitset<TransportFlagsSize>;

}  // namespace detail
/// A set of transport flags, privately backed by detail::TransportFlagsBase.
class TransportFlags : private detail::TransportFlagsBase {
 public:
  /// Constructor.
  TransportFlags() = default;

  /// Constructor.
  /// @param transport Transport whose flag should be set.
  TransportFlags(Transport transport);

  /// Test whether the flag of a given transport is set.
  /// @param transport Transport whose flag is tested.
  /// @return True if the flag is set, false otherwise.
  bool has(Transport transport) const;

  /// Test whether no flag is set.
  /// @return True if no flags are set, false otherwise.
  bool none() const;

  /// Test whether at least one flag is set.
  /// @return True if any flags are set, false otherwise.
  bool any() const;

  /// Test whether every flag is set.
  /// @return True if all flags are set, false otherwise.
  bool all() const;

  /// Count the flags that are set.
  /// @return The number of flags that are set.
  size_t count() const;

  /// Bitwise OR assignment.
  /// @param other Right-hand TransportFlags of the OR operation.
  /// @return Reference to this (modified) TransportFlags.
  TransportFlags& operator|=(TransportFlags other);

  /// Bitwise OR with another TransportFlags.
  /// @param other Right-hand TransportFlags of the OR operation.
  /// @return New TransportFlags holding the result of the OR operation.
  TransportFlags operator|(TransportFlags other) const;

  /// Bitwise OR with a single Transport.
  /// @param transport Right-hand Transport of the OR operation.
  /// @return New TransportFlags holding the result of the OR operation.
  TransportFlags operator|(Transport transport) const;

  /// Bitwise AND assignment.
  /// @param other Right-hand TransportFlags of the AND operation.
  /// @return Reference to this (modified) TransportFlags.
  TransportFlags& operator&=(TransportFlags other);

  /// Bitwise AND with another TransportFlags.
  /// @param other Right-hand TransportFlags of the AND operation.
  /// @return New TransportFlags holding the result of the AND operation.
  TransportFlags operator&(TransportFlags other) const;

  /// Bitwise AND with a single Transport.
  /// @param transport Right-hand Transport of the AND operation.
  /// @return New TransportFlags holding the result of the AND operation.
  TransportFlags operator&(Transport transport) const;

  /// Bitwise XOR assignment.
  /// @param other Right-hand TransportFlags of the XOR operation.
  /// @return Reference to this (modified) TransportFlags.
  TransportFlags& operator^=(TransportFlags other);

  /// Bitwise XOR with another TransportFlags.
  /// @param other Right-hand TransportFlags of the XOR operation.
  /// @return New TransportFlags holding the result of the XOR operation.
  TransportFlags operator^(TransportFlags other) const;

  /// Bitwise XOR with a single Transport.
  /// @param transport Right-hand Transport of the XOR operation.
  /// @return New TransportFlags holding the result of the XOR operation.
  TransportFlags operator^(Transport transport) const;

  /// Bitwise NOT.
  /// @return New TransportFlags holding the result of the NOT operation.
  TransportFlags operator~() const;

  /// Equality comparison.
  /// @param other TransportFlags to compare against.
  /// @return True if both TransportFlags objects are equal, false otherwise.
  bool operator==(TransportFlags other) const;

  /// Inequality comparison.
  /// @param other TransportFlags to compare against.
  /// @return True if the TransportFlags objects are not equal, false otherwise.
  bool operator!=(TransportFlags other) const;

  /// Convert to the bitset representation.
  /// @return A detail::TransportFlagsBase representing this TransportFlags.
  detail::TransportFlagsBase toBitset() const;

 private:
  /// Private constructor taking a bitset representation.
  /// @param bitset Bitset representation of the TransportFlags object.
  TransportFlags(detail::TransportFlagsBase bitset);
};
/// Combine two Transport values with bitwise OR.
///
/// @param transport1 Left-hand Transport of the OR operation.
/// @param transport2 Right-hand Transport of the OR operation.
/// @return New TransportFlags holding the result of the OR operation.
inline TransportFlags operator|(Transport transport1, Transport transport2) {
  // Lift the left operand into a flag set, then defer to TransportFlags::operator|(Transport).
  const TransportFlags lhs(transport1);
  return lhs | transport2;
}
/// Combine two Transport values with bitwise AND.
///
/// @param transport1 Left-hand Transport of the AND operation.
/// @param transport2 Right-hand Transport of the AND operation.
/// @return New TransportFlags holding the result of the AND operation.
inline TransportFlags operator&(Transport transport1, Transport transport2) {
  // Lift the left operand into a flag set, then defer to TransportFlags::operator&(Transport).
  const TransportFlags lhs(transport1);
  return lhs & transport2;
}
/// Combine two Transport values with bitwise XOR.
///
/// @param transport1 Left-hand Transport of the XOR operation.
/// @param transport2 Right-hand Transport of the XOR operation.
/// @return New TransportFlags holding the result of the XOR operation.
inline TransportFlags operator^(Transport transport1, Transport transport2) {
  // Lift the left operand into a flag set, then defer to TransportFlags::operator^(Transport).
  const TransportFlags lhs(transport1);
  return lhs ^ transport2;
}
/// Available device types.
enum class DeviceType {
  Unknown,  // Unknown device type.
  CPU,      // CPU device type.
  GPU,      // GPU device type.
};

/// Declaration of a device.
struct Device {
  /// Constructor. Produces an Unknown device with no specific ID (-1), via the default member
  /// initializers below.
  Device() = default;

  /// Constructor.
  /// @param type Device type.
  /// @param id Device ID. Default is -1 (no specific ID).
  Device(DeviceType type, int id = -1) : type(type), id(id) {}

  /// Device Type.
  // Initialized in-class so that a default-initialized Device (`Device d;`) never holds an
  // indeterminate value.
  DeviceType type = DeviceType::Unknown;

  /// Device ID.
  // -1 matches the "no specific ID" default of the two-argument constructor.
  int id = -1;
};
/// Settings used when creating a communication endpoint.
struct EndpointConfig {
  /// InfiniBand-specific options controlling queue pair behavior and performance characteristics.
  /// They only take effect when the transport is an InfiniBand type (IB0-IB7) and are ignored for any
  /// other transport.
  struct Ib {
    /// Signaling mode, used to select between different implementations.
    enum class Mode {
      Default,      // Use the MSCCLPP_IBV_MODE environment variable (or "host" if unset).
      Host,         // Use the host stack with RDMA atomics.
      HostNoAtomic  // Use the host stack with write-with-immediate signaling (no RDMA atomics).
    };

    // Default values for the constructor parameters below.
    static constexpr int DefaultPort = -1;
    static constexpr int DefaultGidIndex = -1;
    static constexpr int DefaultMaxCqSize = 1024;
    static constexpr int DefaultMaxCqPollNum = 1;
    static constexpr int DefaultMaxSendWr = 8192;
    static constexpr int DefaultMaxRecvWr = 16;
    static constexpr int DefaultMaxWrPerSend = 64;

    /// Device index. Currently ignored; the transport type (IB0-IB7) selects the device.
    int deviceIndex;
    /// Port number.
    int port;
    /// GID index.
    int gidIndex;
    /// Maximum size of the send completion queue.
    int maxCqSize;
    /// Maximum number of send completion queue polls per operation.
    int maxCqPollNum;
    /// Maximum number of outstanding send work requests.
    int maxSendWr;
    /// Maximum number of outstanding receive work requests (used in HostNoAtomic mode for
    /// write-with-immediate).
    int maxRecvWr;
    /// Maximum number of work requests per send operation.
    int maxWrPerSend;
    /// Signaling mode. Default means the MSCCLPP_IBV_MODE environment variable is used.
    Mode mode;

    /// Constructor.
    /// @param deviceIndex Device index.
    /// @param port Port number.
    /// @param gidIndex GID index. If -1 (default), uses `MSCCLPP_IB_GID_INDEX` env variable.
    /// @param maxCqSize Maximum send completion queue size.
    /// @param maxCqPollNum Maximum send completion queue poll count.
    /// @param maxSendWr Maximum outstanding send work requests.
    /// @param maxRecvWr Maximum outstanding receive work requests (for HostNoAtomic mode).
    /// @param maxWrPerSend Maximum work requests per send operation.
    /// @param mode IB mode for signaling (Default uses MSCCLPP_IBV_MODE env variable).
    Ib(int deviceIndex = -1, int port = DefaultPort, int gidIndex = DefaultGidIndex,
       int maxCqSize = DefaultMaxCqSize, int maxCqPollNum = DefaultMaxCqPollNum,
       int maxSendWr = DefaultMaxSendWr, int maxRecvWr = DefaultMaxRecvWr,
       int maxWrPerSend = DefaultMaxWrPerSend, Mode mode = Mode::Default)
        : deviceIndex(deviceIndex),
          port(port),
          gidIndex(gidIndex),
          maxCqSize(maxCqSize),
          maxCqPollNum(maxCqPollNum),
          maxSendWr(maxSendWr),
          maxRecvWr(maxRecvWr),
          maxWrPerSend(maxWrPerSend),
          mode(mode) {}
  };

  /// Communication transport type (e.g., CudaIpc, IB0-IB7, Ethernet).
  Transport transport;
  /// Target device of the endpoint (GPU or CPU, with an optional device ID).
  Device device;
  /// Maximum number of write requests that can be queued (-1 for default).
  int maxWriteQueueSize;
  /// InfiniBand-specific options (used only for Transport::IBx).
  Ib ib;

  /// Build an endpoint configuration from a transport, a device, and optional settings.
  /// @param transport Communication transport to use.
  /// @param device Target device for the endpoint.
  /// @param maxWriteQueueSize Maximum write queue size (-1 for system default).
  /// @param ib IB-specific configuration.
  EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
                 int maxWriteQueueSize = -1, Ib ib = {})
      : transport(transport), device(device), maxWriteQueueSize(maxWriteQueueSize), ib(ib) {}
};
// Forward declarations, so the classes below can refer to each other before their full definitions.
class Context;
class Connection;
class BaseConnection;
class RegisteredMemory;
class SemaphoreStub;
class Semaphore;
/// One of the two ends of a connection.
class Endpoint {
 public:
  /// Constructor.
  Endpoint() = default;

  /// Access the configuration this endpoint was created with.
  /// @return The configuration used to create the endpoint.
  const EndpointConfig& config() const;

  /// Query the transport in use.
  /// @return The transport used.
  Transport transport() const;

  /// Query the device in use.
  /// @return The device used.
  const Device& device() const;

  /// Query the host hash.
  /// @return The host hash.
  uint64_t hostHash() const;

  /// Query the process ID hash.
  /// @return The process ID hash.
  uint64_t pidHash() const;

  /// Query the maximum write queue size.
  /// @return The maximum number of write requests that can be queued.
  int maxWriteQueueSize() const;

  /// Serialize this Endpoint into a vector of characters.
  /// @return A vector of characters representing the serialized Endpoint object.
  std::vector<char> serialize() const;

  /// Rebuild an Endpoint from a vector of characters.
  /// @param data A vector of characters representing a serialized Endpoint object.
  /// @return A deserialized Endpoint object.
  static Endpoint deserialize(const std::vector<char>& data);

 private:
  struct Impl;
  Endpoint(std::shared_ptr<Impl> pimpl);
  std::shared_ptr<Impl> pimpl_;

  friend class Context;
  friend class BaseConnection;
};
/// Context for communication. It offers a low-level interface for forming connections in use-cases where
/// the process group abstraction provided by Communicator does not fit, e.g., ephemeral client-server
/// connections. Using this class correctly requires external synchronization when finalizing connections
/// with the connect() method.
///
/// For example, a client-server scenario in which the server writes to the client might go as follows:
///   1. The client creates an endpoint with createEndpoint() and sends it to the server.
///   2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it
///      to the client, and creates a connection with connect().
///   3. The client receives the server endpoint, creates a connection with connect() and sends a
///      RegisteredMemory to the server.
///   4. The server receives the RegisteredMemory and writes to it using the previously created connection.
/// Because the client creates its connection before sending the RegisteredMemory, the server cannot write
/// to the RegisteredMemory before the connection is established.
///
/// Some transports may behave more leniently in their implementation, but this should not be relied upon.
class Context : public std::enable_shared_from_this<Context> {
 public:
  /// Create a new Context instance.
  static std::shared_ptr<Context> create() {
    // The constructor is private, so the instance is allocated here and handed directly to a shared_ptr.
    std::shared_ptr<Context> instance(new Context());
    return instance;
  }

  /// Destructor.
  ~Context();

  /// Register a region of GPU memory for use in this context.
  ///
  /// @param ptr Base pointer to the memory.
  /// @param size Size of the memory region in bytes.
  /// @param transports Transport flags.
  /// @return A RegisteredMemory object representing the registered memory region.
  RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

  /// Create an endpoint that connections can be established from.
  ///
  /// @param config The configuration for the endpoint.
  /// @return The newly created endpoint.
  Endpoint createEndpoint(EndpointConfig config);

  /// Establish a connection between two endpoints. The connection object is returned immediately, but it
  /// is only safe to use once the corresponding connection on the remote endpoint has been established.
  /// Both endpoints must call this method for the connection to be established.
  ///
  /// @param localEndpoint The local endpoint.
  /// @param remoteEndpoint The remote endpoint.
  /// @return A connection object.
  Connection connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);

 private:
  Context();

  struct Impl;
  std::unique_ptr<Impl> pimpl_;

  friend class Endpoint;
  friend class BaseConnection;
  friend class RegisteredMemory;
  friend class SemaphoreStub;
};
/// Block of memory that has been registered to a Context.
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
/// to other processes, hence allowing their access to the memory block.
class RegisteredMemory {
 public:
  /// Constructor.
  RegisteredMemory() = default;

  // The destructor below is user-declared, which suppresses the implicit move operations and makes the
  // implicit copy operations deprecated. All of them are therefore declared explicitly.

  /// Copy constructor. The copy shares the same implementation object as @p other.
  RegisteredMemory(const RegisteredMemory& other) = default;

  /// Move constructor. Transfers the implementation pointer out of @p other.
  RegisteredMemory(RegisteredMemory&& other) noexcept = default;

  /// Copy assignment. Afterwards this object shares the same implementation object as @p other.
  RegisteredMemory& operator=(const RegisteredMemory& other) = default;

  /// Move assignment. Transfers the implementation pointer out of @p other.
  RegisteredMemory& operator=(RegisteredMemory&& other) noexcept = default;

  /// Destructor.
  ~RegisteredMemory();

  /// Get a pointer to the memory block.
  /// @return A pointer to the memory block.
  void* data() const;

  /// Get a pointer to the original memory block.
  /// @return A pointer to the original memory block.
  void* originalDataPtr() const;

  /// Get the size of the memory block.
  /// @return The size of the memory block.
  size_t size() const;

  /// Get the transport flags associated with the memory block.
  /// @return The transport flags associated with the memory block.
  TransportFlags transports() const;

  /// Serialize the RegisteredMemory object to a vector of characters.
  /// @return A vector of characters representing the serialized RegisteredMemory object.
  std::vector<char> serialize() const;

  /// Deserialize a RegisteredMemory object from a vector of characters.
  /// @param data A vector of characters representing a serialized RegisteredMemory object.
  /// @return A deserialized RegisteredMemory object.
  static RegisteredMemory deserialize(const std::vector<char>& data);

 private:
  struct Impl;
  RegisteredMemory(std::shared_ptr<Impl> pimpl);
  std::shared_ptr<Impl> pimpl_;

  friend class Context;
  friend class BaseConnection;
  friend class SemaphoreStub;
  friend class Semaphore;
};
/// A connection between two processes.
class Connection {
 public:
  /// Constructor.
  Connection() = default;

  /// Write data from a source RegisteredMemory into a destination RegisteredMemory.
  ///
  /// @param dst The destination RegisteredMemory.
  /// @param dstOffset Byte offset from the start of the destination RegisteredMemory.
  /// @param src The source RegisteredMemory.
  /// @param srcOffset Byte offset from the start of the source RegisteredMemory.
  /// @param size The number of bytes to write.
  void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);

  /// Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote
  /// process.
  ///
  /// @param dst The destination RegisteredMemory.
  /// @param dstOffset Byte offset from the start of the destination RegisteredMemory.
  /// @param src A pointer to the value to update.
  /// @param newValue The new value to write.
  void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue);

  /// Flush any pending writes to the remote process.
  /// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
  void flush(int64_t timeoutUsec = -1);

  /// Query the transport used by the local process.
  /// @return The transport used by the local process.
  Transport transport() const;

  /// Query the transport used by the remote process.
  /// @return The transport used by the remote process.
  Transport remoteTransport() const;

  /// Access the context this connection is associated with.
  /// @return A shared pointer to the context associated with this connection.
  std::shared_ptr<Context> context() const;

  /// Query the device used by the local endpoint.
  /// @return The device used by the local endpoint.
  const Device& localDevice() const;

  /// Query the maximum write queue size.
  /// @return The maximum number of write requests that can be queued.
  int getMaxWriteQueueSize() const;

 private:
  Connection(std::shared_ptr<BaseConnection> impl);
  std::shared_ptr<BaseConnection> impl_;

  friend class Context;
  friend class Communicator;
  friend class SemaphoreStub;
  friend class Semaphore;
  friend class ProxyService;
  friend class BaseConnection;
};
/// Building block used only to construct a Semaphore; not meant for direct use by the user.
class SemaphoreStub {
 public:
  /// Constructor.
  /// @param connection The connection associated with this semaphore.
  SemaphoreStub(const Connection& connection);

  /// Access the memory associated with this semaphore.
  /// @return A reference to the registered memory for this semaphore.
  const RegisteredMemory& memory() const;

  /// Serialize this object into a vector of characters.
  /// @return A vector of characters representing the serialized SemaphoreStub object.
  std::vector<char> serialize() const;

  /// Rebuild a SemaphoreStub from a vector of characters.
  /// @param data A vector of characters representing a serialized SemaphoreStub object.
  /// @return A deserialized SemaphoreStub object.
  static SemaphoreStub deserialize(const std::vector<char>& data);

 protected:
  struct Impl;
  SemaphoreStub(std::shared_ptr<Impl> pimpl);
  std::shared_ptr<Impl> pimpl_;

  friend class Semaphore;
};
/// Synchronization primitive used by channels.
class Semaphore {
 public:
  /// Constructor.
  Semaphore() = default;

  /// Constructor.
  /// @param localStub SemaphoreStub allocated on the local process.
  /// @param remoteStub SemaphoreStub allocated on the remote process.
  Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);

  /// Access the connection associated with this semaphore.
  /// @return The connection.
  Connection& connection();

  /// Access the local memory associated with this semaphore.
  /// @return A reference to the local registered memory.
  const RegisteredMemory& localMemory() const;

  /// Access the remote memory associated with this semaphore.
  /// @return A reference to the remote registered memory.
  const RegisteredMemory& remoteMemory() const;

 protected:
  struct Impl;
  std::shared_ptr<Impl> pimpl_;
};
/// Deprecated. Alias of `std::shared_future<T>`.
template <class T>
using NonblockingFuture = std::shared_future<T>;
/// A class that sets up all registered memories and connections between processes.
///
/// A typical way to use this class:
/// 1. Call connect() to declare connections between the calling process and other processes.
/// 2. Call registerMemory() to register memory regions that will be used for communication.
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
/// other processes.
/// 4. Call get() on futures returned by connect(). Use the returned connections to create flags.
/// 5. Call buildSemaphore() to create a Semaphore out of the flags.
/// 6. Call get() on all futures returned by buildSemaphore() and recvMemory().
/// 7. All done; use semaphores and registered memories to build channels.
///
/// CAUTION: The order of method calls matters when the same remote rank and tags are used. That is, the i-th
/// "sending" method call (sendMemory(), connect(), and buildSemaphore()) on the local rank must be matched
/// by the i-th "receiving" method call (recvMemory(), connect(), and buildSemaphore()) on the remote rank.
///
/// Correct Example 1:
/// ```cpp
/// // Rank 0
/// communicator.sendMemory(memory1, 1, tag);
/// communicator.sendMemory(memory2, 1, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// // Rank 1
/// auto mem1 = communicator.recvMemory(0, tag);
/// auto mem2 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem2.get(); // This will return memory2.
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
/// ```
///
/// Correct Example 2:
/// ```cpp
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// communicator.sendMemory(memory1, 0, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem0.get(); // This will return memory0.
/// connection.get(); // This will return the connection.
/// ```
///
/// Wrong Example:
/// ```cpp
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag); // undefined behavior
/// communicator.sendMemory(memory1, 0, tag);
/// ```
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
/// where the object type is RegisteredMemory, not Connection.
///
class Communicator {
 public:
  /// Construct a communicator on top of a bootstrap implementation.
  ///
  /// @param bootstrap The Bootstrap implementation this communicator will use.
  /// @param context Optional context for the communicator. If not provided, a new context will be created.
  Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context = nullptr);

  /// Destroy the communicator.
  ~Communicator();

  /// Get the bootstrap held by this communicator.
  /// @return The bootstrap held by this communicator.
  std::shared_ptr<Bootstrap> bootstrap();

  /// Get the context held by this communicator.
  /// @return The context held by this communicator.
  std::shared_ptr<Context> context();

  /// Register a GPU memory region so it can be used within this communicator's context.
  ///
  /// @param ptr Base pointer of the memory region.
  /// @param size Size of the memory region, in bytes.
  /// @param transports Transport flags.
  /// @return A RegisteredMemory object that represents the registered region.
  RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

  /// Send the information of a registered memory to a remote rank.
  ///
  /// The send is started by this call, which then returns immediately without waiting for the send to
  /// complete. A RegisteredMemory sent via `sendMemory(memory, remoteRank, tag)` is received on the remote
  /// side via `recvMemory(remoteRank, tag)`.
  ///
  /// Ordering: multiple sendMemory() or connect() calls that use the same @p remoteRank and @p tag are ordered
  /// by call order, which corresponds to the order of recvMemory() or connect() calls on the receiving side.
  /// When the execution order between two ranks is unknown, give the calls different @p tag values to prevent
  /// unexpected behavior.
  ///
  /// @param memory The registered memory buffer whose information is sent.
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the send.
  void sendMemory(RegisteredMemory memory, int remoteRank, int tag = 0);

  /// Deprecated name of sendMemory(); forwards its arguments to sendMemory() unchanged.
  /// Unlike sendMemory(), @p tag has no default value here.
  ///
  /// @param memory The registered memory buffer whose information is sent.
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the send.
  [[deprecated("Use sendMemory() instead. This will be removed in a future release.")]]
  void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) {
    this->sendMemory(memory, remoteRank, tag);
  }

  /// Receive the memory information sent by a corresponding sendMemory() call on the remote side.
  ///
  /// A future is returned immediately; the receive itself is performed when get() is first called on that
  /// future. A RegisteredMemory sent via `sendMemory(memory, remoteRank, tag)` is received via
  /// `recvMemory(remoteRank, tag)`.
  ///
  /// Ordering: multiple sendMemory() or connect() calls that use the same @p remoteRank and @p tag are ordered
  /// by call order, which corresponds to the order of recvMemory() or connect() calls on the receiving side.
  /// When the execution order between two ranks is unknown, give the calls different @p tag values to prevent
  /// unexpected behavior.
  ///
  /// @note To keep the receiving order, get() on a future returned by recvMemory() or connect() may first start
  /// receiving other RegisteredMemory or Connection objects whose futures came from earlier recvMemory() or
  /// connect() calls with the same @p remoteRank and @p tag. For example, after five recvMemory() or connect()
  /// calls with the same @p remoteRank and @p tag, calling get() on the last future starts receiving all five
  /// RegisteredMemory or Connection objects in order, back to back.
  ///
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the receive.
  /// @return A future of the registered memory.
  std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag = 0);

  /// Deprecated name of recvMemory(); forwards its arguments to recvMemory() and returns the result.
  /// Unlike recvMemory(), @p tag has no default value here.
  ///
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the receive.
  /// @return The result of `recvMemory(remoteRank, tag)`, returned as `NonblockingFuture<RegisteredMemory>`.
  [[deprecated("Use recvMemory() instead. This will be removed in a future release.")]]
  NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag) {
    return this->recvMemory(remoteRank, tag);
  }

  /// Connect to a remote rank.
  ///
  /// This call starts sending metadata about the local endpoint to the remote rank (without waiting for the
  /// send) and returns a future connection without waiting for the remote rank to respond. The connection is
  /// established once the remote rank has responded with its own endpoint and the local rank calls get() on
  /// the future for the first time.
  ///
  /// The operation is two-way: a connection from rank `i` to remote rank `j` needs a counterpart from rank `j`
  /// to rank `i`.
  ///
  /// With IB, buffers are registered at a page level. If a buffer is spread through multiple pages without
  /// fully utilizing all of them, IB's QP has to register for all involved pages. This potentially has
  /// security risks if the connection's accesses are given to a malicious process.
  ///
  /// Ordering: multiple sendMemory() or connect() calls that use the same @p remoteRank and @p tag are ordered
  /// by call order, which corresponds to the order of recvMemory() or connect() calls on the receiving side.
  /// When the execution order between two ranks is unknown, give the calls different @p tag values to prevent
  /// unexpected behavior.
  ///
  /// @note To keep the receiving order, get() on a future returned by recvMemory() or connect() may first start
  /// receiving other RegisteredMemory or Connection objects whose futures came from earlier recvMemory() or
  /// connect() calls with the same @p remoteRank and @p tag. For example, after five recvMemory() or connect()
  /// calls with the same @p remoteRank and @p tag, calling get() on the last future starts receiving all five
  /// RegisteredMemory or Connection objects in order, back to back.
  ///
  /// @param localEndpoint The local endpoint.
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the send and the receive.
  /// @return A future of the connection.
  std::shared_future<Connection> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);

  /// Connect to a remote rank. This is a wrapper of `connect(localEndpoint, remoteRank, tag)`.
  ///
  /// @param localConfig The configuration for the local endpoint.
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the send and the receive.
  /// @return A future of the connection.
  std::shared_future<Connection> connect(const EndpointConfig& localConfig, int remoteRank, int tag = 0);

  /// Build a semaphore for cross-process synchronization.
  ///
  /// @param connection The connection associated with this semaphore.
  /// @param remoteRank The rank of the remote process.
  /// @param tag The tag that identifies the operation.
  /// @return A future of the built semaphore.
  std::shared_future<Semaphore> buildSemaphore(const Connection& connection, int remoteRank, int tag = 0);

  /// Look up the remote rank a connection is connected to.
  ///
  /// @param connection The connection to query.
  /// @return The remote rank the connection is connected to.
  int remoteRankOf(const Connection& connection);

  /// Look up the tag a connection was made with.
  ///
  /// @param connection The connection to query.
  /// @return The tag the connection was made with.
  int tagOf(const Connection& connection);

  /// Deprecated no-op, kept so that existing callers still compile; the body is empty.
  [[deprecated("setup() is now no-op and no longer needed. This will be removed in a future release.")]]
  void setup() {}

 private:
  // Implementation type; only forward-declared in this class.
  struct Impl;
  // Owning pointer to the implementation object.
  std::unique_ptr<Impl> pimpl_;
};
/// A constant TransportFlags object representing no transports.
/// Declaration only (`extern`, no initializer); the object itself is defined elsewhere.
extern const TransportFlags NoTransports;
/// A constant TransportFlags object representing all InfiniBand transports.
/// Declaration only (`extern`, no initializer); the object itself is defined elsewhere.
extern const TransportFlags AllIBTransports;
/// A constant TransportFlags object representing all transports.
/// Declaration only (`extern`, no initializer); the object itself is defined elsewhere.
extern const TransportFlags AllTransports;
/// Alias for the nested `DeviceHandle` type of @p T, i.e. a type which could be safely used on the device side.
/// @tparam T A type that declares a nested `DeviceHandle` type.
template <typename T>
using DeviceHandle = typename T::DeviceHandle;

/// Retrieve the device handle instance from a host object by calling its `deviceHandle()` member.
/// @tparam T Deduced type of the host object; any reference is stripped before the handle type is looked up.
/// @param t The host object to query.
/// @return The result of `t.deviceHandle()`, returned as `DeviceHandle<std::remove_reference_t<T>>`.
template <typename T>
DeviceHandle<std::remove_reference_t<T>> deviceHandle(T&& t) {
  return t.deviceHandle();
}
/// Packet value type: alias for the nested `Payload` type of @p T.
/// @tparam T A type that declares a nested `Payload` type.
template <typename T>
using PacketPayload = typename T::Payload;
/// Stream insertion overload for Transport: convert the transport to a string and output it to the stream.
/// Declaration only; the definition is not part of this section of the header.
/// @param os Output stream.
/// @param transport Input transport.
/// @return Output stream.
std::ostream& operator<<(std::ostream& os, const Transport& transport);
/// Stream insertion overload for DeviceType: convert the device type to a string and output it to the stream.
/// Declaration only; the definition is not part of this section of the header.
/// @param os Output stream.
/// @param deviceType Input device type.
/// @return Output stream.
std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType);
/// Stream insertion overload for Device: convert the device to a string and output it to the stream.
/// Declaration only; the definition is not part of this section of the header.
/// @param os Output stream.
/// @param device Input device.
/// @return Output stream.
std::ostream& operator<<(std::ostream& os, const Device& device);
} // namespace mscclpp
#endif // MSCCLPP_CORE_HPP_