New semaphore constructors (#559)

More intuitive interfaces for creating semaphores and channels. Also
allows channel construction using third-party bootstrappers directly
without overriding MSCCL++ Bootstrap.
This commit is contained in:
Changho Hwang
2025-07-11 17:10:46 -07:00
committed by GitHub
parent 20eca28942
commit ae56698d67
42 changed files with 847 additions and 529 deletions

View File

@@ -275,8 +275,8 @@ static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_pt
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
if (i == rank) continue;
remoteRegMemoryFutures.push_back(comm->recvMemory(i, 0));
comm->sendMemory(memory, i, 0);
remoteRegMemoryFutures.push_back(comm->recvMemory(i));
comm->sendMemory(memory, i);
}
std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories),
[](const auto& future) { return future.get(); });
@@ -613,7 +613,7 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) {
if (i == rank) continue;
mscclpp::Transport transport = getTransport(rank, i);
connectionFutures.push_back(mscclppComm->connect(i, 0, transport));
connectionFutures.push_back(mscclppComm->connect(transport, i));
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),

View File

@@ -39,16 +39,17 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
if (r == rank) continue;
mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
// Connect with all other ranks
connections[r] = comm.connect(r, 0, transport);
connections[r] = comm.connect(transport, r);
auto memory = comm.registerMemory(data, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemory(memory, r, 0);
remoteMemories.push_back(comm.recvMemory(r, 0));
comm.sendMemory(memory, r);
remoteMemories.push_back(comm.recvMemory(r));
}
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
auto sema = communicator->buildSemaphore(connections[r].get(), r).get();
semaphoreIds.push_back(proxyService->addSemaphore(sema));
}
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;

View File

@@ -21,11 +21,11 @@ namespace mscclpp {
#define MSCCLPP_UNIQUE_ID_BYTES 128
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
/// Unique ID for initializing the TcpBootstrap.
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
/// Return a version string.
/// @return A string representing the version of MSCCL++ in the format "major.minor.patch".
/// @return The MSCCL++ version string in "major.minor.patch" format.
std::string version();
/// Base class for bootstraps.
@@ -220,9 +220,6 @@ enum class Transport {
NumTransports, // The number of transports.
};
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
namespace detail {
const size_t TransportFlagsSize = 12;
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
@@ -234,119 +231,99 @@ using TransportFlagsBase = std::bitset<TransportFlagsSize>;
/// Stores transport flags.
class TransportFlags : private detail::TransportFlagsBase {
public:
/// Default constructor for TransportFlags.
/// Constructor.
TransportFlags() = default;
/// Constructor for TransportFlags that takes a Transport enum value.
///
/// Constructor.
/// @param transport The transport to set the flag for.
TransportFlags(Transport transport);
/// Check if a specific transport flag is set.
///
/// @param transport The transport to check the flag for.
/// @return True if the flag is set, false otherwise.
bool has(Transport transport) const;
/// Check if no transport flags are set.
///
/// @return True if no flags are set, false otherwise.
bool none() const;
/// Check if any transport flags are set.
///
/// @return True if any flags are set, false otherwise.
bool any() const;
/// Check if all transport flags are set.
///
/// @return True if all flags are set, false otherwise.
bool all() const;
/// Get the number of transport flags that are set.
///
/// @return The number of flags that are set.
size_t count() const;
/// Bitwise OR assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the OR operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator|=(TransportFlags other);
/// Bitwise OR operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the OR operation with.
/// @return A new TransportFlags object with the result of the OR operation.
TransportFlags operator|(TransportFlags other) const;
/// Bitwise OR operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the OR operation with.
/// @return A new TransportFlags object with the result of the OR operation.
TransportFlags operator|(Transport transport) const;
/// Bitwise AND assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the AND operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator&=(TransportFlags other);
/// Bitwise AND operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the AND operation with.
/// @return A new TransportFlags object with the result of the AND operation.
TransportFlags operator&(TransportFlags other) const;
/// Bitwise AND operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the AND operation with.
/// @return A new TransportFlags object with the result of the AND operation.
TransportFlags operator&(Transport transport) const;
/// Bitwise XOR assignment operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the XOR operation with.
/// @return A reference to the modified TransportFlags.
TransportFlags& operator^=(TransportFlags other);
/// Bitwise XOR operator for TransportFlags.
///
/// @param other The other TransportFlags to perform the XOR operation with.
/// @return A new TransportFlags object with the result of the XOR operation.
TransportFlags operator^(TransportFlags other) const;
/// Bitwise XOR operator for TransportFlags and Transport.
///
/// @param transport The Transport to perform the XOR operation with.
/// @return A new TransportFlags object with the result of the XOR operation.
TransportFlags operator^(Transport transport) const;
/// Bitwise NOT operator for TransportFlags.
///
/// @return A new TransportFlags object with the result of the NOT operation.
TransportFlags operator~() const;
/// Equality comparison operator for TransportFlags.
///
/// @param other The other TransportFlags to compare with.
/// @return True if the two TransportFlags objects are equal, false otherwise.
bool operator==(TransportFlags other) const;
/// Inequality comparison operator for TransportFlags.
///
/// @param other The other TransportFlags to compare with.
/// @return True if the two TransportFlags objects are not equal, false otherwise.
bool operator!=(TransportFlags other) const;
/// Convert the TransportFlags object to a bitset representation.
///
/// @return A detail::TransportFlagsBase object representing the TransportFlags object.
detail::TransportFlagsBase toBitset() const;
private:
/// Private constructor for TransportFlags that takes a bitset representation.
///
/// @param bitset The bitset representation of the TransportFlags object.
TransportFlags(detail::TransportFlagsBase bitset);
};
@@ -378,47 +355,64 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) {
return TransportFlags(transport1) ^ transport2;
}
/// Available device types.
enum class DeviceType {
Unknown, // Unknown device type.
CPU, // CPU device type.
GPU, // GPU device type.
};
struct Device {
/// Constructor.
Device() = default;
/// Constructor.
/// @param type Device type.
/// @param id Device ID. Default is -1 (no specific ID).
Device(DeviceType type, int id = -1) : type(type), id(id) {}
/// Device Type.
DeviceType type;
/// Device ID.
int id;
};
class Context;
class Connection;
/// Represents a block of memory that has been registered to a Context.
/// Block of memory that has been registered to a Context.
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
/// to other processes, hence allowing their access to the memory block.
class RegisteredMemory {
public:
/// Default constructor.
/// Constructor.
RegisteredMemory() = default;
/// Destructor.
~RegisteredMemory();
/// Get a pointer to the memory block.
///
/// @return A pointer to the memory block.
void* data() const;
/// Get a pointer to the original memory block.
///
/// @return A pointer to the original memory block.
void* originalDataPtr() const;
/// Get the size of the memory block.
///
/// @return The size of the memory block.
size_t size() const;
/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
TransportFlags transports() const;
/// Serialize the RegisteredMemory object to a vector of characters.
///
/// @return A vector of characters representing the serialized RegisteredMemory object.
std::vector<char> serialize() const;
/// Deserialize a RegisteredMemory object from a vector of characters.
///
/// @param data A vector of characters representing a serialized RegisteredMemory object.
/// @return A deserialized RegisteredMemory object.
static RegisteredMemory deserialize(const std::vector<char>& data);
@@ -430,31 +424,32 @@ class RegisteredMemory {
friend class Context;
friend class Connection;
friend class SemaphoreStub;
};
/// Represents one end of a connection.
/// One end of a connection.
class Endpoint {
public:
/// Default constructor.
/// Constructor.
Endpoint() = default;
/// Get the transport used.
///
/// @return The transport used.
Transport transport();
Transport transport() const;
/// Get the device used.
/// @return The device used.
const Device& device() const;
/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int maxWriteQueueSize();
int maxWriteQueueSize() const;
/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();
std::vector<char> serialize() const;
/// Deserialize an Endpoint object from a vector of characters.
///
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
static Endpoint deserialize(const std::vector<char>& data);
@@ -468,13 +463,13 @@ class Endpoint {
friend class Connection;
};
/// Represents a connection between two processes.
/// Connection between two processes.
class Connection {
public:
/// Constructor.
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
: context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
/// @param localEndpoint The local endpoint of the connection.
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
/// Destructor.
virtual ~Connection() = default;
@@ -498,34 +493,35 @@ class Connection {
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
/// Flush any pending writes to the remote process.
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout).
virtual void flush(int64_t timeoutUsec = -1) = 0;
/// Get the transport used by the local process.
///
/// @return The transport used by the local process.
virtual Transport transport() const = 0;
/// Get the transport used by the remote process.
///
/// @return The transport used by the remote process.
virtual Transport remoteTransport() const = 0;
/// Get the name of the transport used for this connection.
///
/// @return A string formatted as "localTransportName -> remoteTransportName".
std::string getTransportName() const;
/// Get the context associated with this connection.
/// @return A shared pointer to the context associated with this connection.
std::shared_ptr<Context> context() const { return context_; }
/// Get the maximum write queue size
///
/// Get the device used by the local endpoint.
/// @return The device used by the local endpoint.
const Device& localDevice() const;
/// Get the maximum write queue size.
/// @return The maximum number of write requests that can be queued.
int getMaxWriteQueueSize() const;
protected:
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
std::shared_ptr<Context> context_;
Endpoint localEndpoint_;
int maxWriteQueueSize_;
};
@@ -537,6 +533,7 @@ struct EndpointConfig {
static const int DefaultMaxWrPerSend = 64;
Transport transport;
Device device;
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
@@ -546,15 +543,18 @@ struct EndpointConfig {
/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
/// @param device The device to use.
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
int maxWriteQueueSize = -1)
: transport(transport),
device(device),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
@@ -562,7 +562,7 @@ struct EndpointConfig {
maxWriteQueueSize(maxWriteQueueSize) {}
};
/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
/// Context for communication. This provides a low-level interface for forming connections in use-cases
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
/// connect() method.
@@ -619,6 +619,62 @@ class Context : public std::enable_shared_from_this<Context> {
friend class Endpoint;
};
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
class SemaphoreStub {
public:
/// Constructor.
/// @param connection A shared pointer to the connection associated with this semaphore.
SemaphoreStub(std::shared_ptr<Connection> connection);
/// Get the memory associated with this semaphore.
/// @return A reference to the registered memory for this semaphore.
const RegisteredMemory& memory() const;
/// Serialize into a vector of characters.
/// @return A vector of characters representing the serialized SemaphoreStub object.
std::vector<char> serialize() const;
/// Deserialize a SemaphoreStub object from a vector of characters.
/// @param data A vector of characters representing a serialized SemaphoreStub object.
/// @return A deserialized SemaphoreStub object.
static SemaphoreStub deserialize(const std::vector<char>& data);
protected:
struct Impl;
SemaphoreStub(std::shared_ptr<Impl> pimpl);
std::shared_ptr<Impl> pimpl_;
friend class Semaphore;
};
/// Semaphore used by channels for synchronization.
class Semaphore {
public:
/// Constructor.
Semaphore() = default;
/// Constructor.
/// @param localStub SemaphoreStub allocated on the local process.
/// @param remoteStub SemaphoreStub allocated on the remote process.
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
/// Get the connection associated with this semaphore.
/// @return A shared pointer to the connection.
std::shared_ptr<Connection> connection() const;
/// Get the local memory associated with this semaphore.
/// @return A reference to the local registered memory.
const RegisteredMemory& localMemory() const;
/// Get the remote memory associated with this semaphore.
/// @return A reference to the remote registered memory.
const RegisteredMemory& remoteMemory() const;
protected:
struct Impl;
std::shared_ptr<Impl> pimpl_;
};
template <typename T>
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
std::shared_future<T>;
@@ -630,20 +686,26 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// 2. Call registerMemory() to register memory regions that will be used for communication.
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
/// other processes.
/// 4. Call get() on all futures returned by connect() and recvMemory().
/// 5. All done; use connections and registered memories to build channels.
/// 4. Call get() on futures returned by connect(). Use the returned connections to create flags.
/// 5. Call buildSemaphore() to create a Semaphore out of the flags.
/// 6. Call get() on all futures returned by buildSemaphore() and recvMemory().
/// 7. All done; use semaphores and registered memories to build channels.
///
/// CAUTION: The order of method calls matters when the same remote rank and tags are used. That is, the i-th
/// "sending" method call (sendMemory(), connect(), and buildSemaphore()) on the local rank must be matched
/// by the i-th "receiving" method call (recvMemory(), connect(), and buildSemaphore()) on the remote rank.
///
/// Correct Example 1:
/// ```cpp
/// // Rank 0
/// communicator.sendMemory(memory1, 1, tag);
/// communicator.sendMemory(memory2, 1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// // Rank 1
/// auto mem1 = communicator.recvMemory(0, tag);
/// auto mem2 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem2.get(); // This will return memory2.
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
@@ -654,13 +716,13 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// connection.get(); // This will return the connection.
/// mem1.get(); // This will return memory1.
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// communicator.sendMemory(memory1, 0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
/// mem0.get(); // This will return memory0.
/// connection.get(); // This will return the connection.
/// ```
@@ -670,10 +732,10 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
/// // Rank 0
/// communicator.sendMemory(memory0, 1, tag);
/// auto mem1 = communicator.recvMemory(1, tag);
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
/// // Rank 1
/// auto mem0 = communicator.recvMemory(0, tag);
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag); // undefined behavior
/// communicator.sendMemory(memory1, 0, tag);
/// ```
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
@@ -691,12 +753,10 @@ class Communicator {
~Communicator();
/// Returns the bootstrap held by this communicator.
///
/// @return The bootstrap held by this communicator.
std::shared_ptr<Bootstrap> bootstrap();
/// Returns the context held by this communicator.
///
/// @return The context held by this communicator.
std::shared_ptr<Context> context();
@@ -723,7 +783,7 @@ class Communicator {
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send.
///
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
void sendMemory(RegisteredMemory memory, int remoteRank, int tag = 0);
[[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup(
RegisteredMemory memory, int remoteRank, int tag) {
@@ -752,7 +812,7 @@ class Communicator {
/// @param tag The tag to use for identifying the receive.
/// @return A future of registered memory.
///
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag);
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag = 0);
[[deprecated(
"Use recvMemory() instead. This will be removed in a future release.")]] NonblockingFuture<RegisteredMemory>
@@ -762,7 +822,7 @@ class Communicator {
/// Connect to a remote rank.
///
/// This function will start (but not be waiting for) sending metadata about the local endpoint to the remote rank,
/// This function will start (but not wait for) sending metadata about the local endpoint to the remote rank,
/// and return a future connection without waiting for the remote rank to respond.
/// The connection will be established when the remote rank responds with its own endpoint and the local rank calls
/// the first get() on the future.
@@ -783,19 +843,30 @@ class Communicator {
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
/// back to back.
///
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @param localConfig The configuration for the local endpoint.
/// @return A future of shared pointer to the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>
connect(int remoteRank, int tag, EndpointConfig localConfig);
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
std::shared_ptr<Connection>>
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
return connect(remoteRank, tag, localConfig);
return connect(localConfig, remoteRank, tag);
}
/// Build a semaphore for cross-process synchronization.
/// @param connection The connection associated with this semaphore.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the operation.
/// @return A future of the built semaphore.
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
@@ -842,6 +913,10 @@ using PacketPayload = typename T::Payload;
namespace std {
std::string to_string(const mscclpp::Transport& transport);
std::string to_string(const mscclpp::Device& device);
/// Specialization of the std::hash template for mscclpp::TransportFlags.
template <>
struct hash<mscclpp::TransportFlags>;

View File

@@ -13,24 +13,24 @@
/// Throw mscclpp::CudaError if @p cmd does not return cudaSuccess.
/// @param cmd The command to execute.
#define MSCCLPP_CUDATHROW(cmd) \
do { \
cudaError_t err = cmd; \
if (err != cudaSuccess) { \
throw mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
#define MSCCLPP_CUDATHROW(cmd) \
do { \
cudaError_t err = cmd; \
if (err != cudaSuccess) { \
throw ::mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
} while (false)
/// Throw mscclpp::CuError if @p cmd does not return CUDA_SUCCESS.
/// @param cmd The command to execute.
#define MSCCLPP_CUTHROW(cmd) \
do { \
CUresult err = cmd; \
if (err != CUDA_SUCCESS) { \
throw mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
#define MSCCLPP_CUTHROW(cmd) \
do { \
CUresult err = cmd; \
if (err != CUDA_SUCCESS) { \
throw ::mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
err); \
} \
} while (false)
namespace mscclpp {

View File

@@ -18,13 +18,19 @@ struct BaseMemoryChannel {
std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore_;
public:
/// Default constructor.
/// Constructor.
BaseMemoryChannel() = default;
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param semaphore Semaphore used to synchronize the communication.
BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore);
/// Constructor.
/// @param semaphore Semaphore used to synchronize the communication.
BaseMemoryChannel(const Semaphore& semaphore);
/// Constructor.
/// @param other Other BaseMemoryChannel to copy from.
BaseMemoryChannel(const BaseMemoryChannel& other) = default;
BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default;
@@ -33,9 +39,8 @@ struct BaseMemoryChannel {
using DeviceHandle = BaseMemoryChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the BaseMemoryChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};
@@ -47,7 +52,7 @@ struct MemoryChannel : public BaseMemoryChannel {
void* packetBuffer_;
public:
/// Default constructor.
/// Constructor.
MemoryChannel() = default;
/// Constructor.
@@ -59,13 +64,20 @@ struct MemoryChannel : public BaseMemoryChannel {
MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
void* packetBuffer = nullptr);
/// Constructor.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param dst Registered memory of the destination.
/// @param src The source memory address.
/// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
/// unpackPacket() and unpackPackets() methods are not available.
MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr);
/// Device-side handle for MemoryChannel.
using DeviceHandle = MemoryChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the MemoryChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};

View File

@@ -35,6 +35,11 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the semaphore.
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added.
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(const Semaphore& semaphore);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added.
/// @return The ID of the semaphore.
@@ -87,22 +92,36 @@ struct BasePortChannel {
std::shared_ptr<Proxy> proxy_;
public:
/// Constructor.
BasePortChannel() = default;
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param proxy The proxy used for communication.
BasePortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
std::shared_ptr<Proxy> proxy);
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore used to synchronize the communication.
/// @param proxy The proxy used for communication.
BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy);
/// Copy constructor.
/// @param other The other BasePortChannel to copy from.
BasePortChannel(const BasePortChannel& other) = default;
/// Assignment operator.
/// @param other The other BasePortChannel to assign from.
BasePortChannel& operator=(BasePortChannel& other) = default;
/// Device-side handle for BasePortChannel.
using DeviceHandle = BasePortChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the BasePortChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};
@@ -113,7 +132,7 @@ struct PortChannel : public BasePortChannel {
MemoryId src_;
public:
/// Default constructor.
/// Constructor.
PortChannel() = default;
/// Constructor.
@@ -125,19 +144,29 @@ struct PortChannel : public BasePortChannel {
PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy,
MemoryId dst, MemoryId src);
/// Constructor.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore.
/// @param proxy The proxy.
/// @param dst The destination memory region.
/// @param src The source memory region.
PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy, MemoryId dst,
MemoryId src);
/// Copy constructor.
/// @param other The other PortChannel to copy from.
PortChannel(const PortChannel& other) = default;
/// Assignment operator.
/// @param other The other PortChannel to assign from.
PortChannel& operator=(PortChannel& other) = default;
/// Device-side handle for PortChannel.
using DeviceHandle = PortChannelDeviceHandle;
/// Returns the device-side handle.
///
/// User should make sure the PortChannel is not released when using the returned handle.
///
/// @return The device-side handle.
DeviceHandle deviceHandle() const;
};

View File

@@ -12,63 +12,18 @@
namespace mscclpp {
/// A base class for semaphores.
///
/// A semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a
/// data transfer. The local peer signals the remote peer that it has completed a data transfer by incrementing the
/// outbound semaphore ID. The incremented outbound semaphore ID is copied to the remote peer's inbound semaphore ID so
/// that the remote peer can wait for the local peer to complete a data transfer. Vice versa, the remote peer signals
/// the local peer that it has completed a data transfer by incrementing the remote peer's outbound semaphore ID and
/// copying the incremented value to the local peer's inbound semaphore ID.
///
/// @tparam InboundDeleter The deleter for inbound semaphore IDs. This is either `std::default_delete` for host memory
/// or CudaDeleter for device memory.
/// @tparam OutboundDeleter The deleter for outbound semaphore IDs. This is either `std::default_delete` for host memory
/// or CudaDeleter for device memory.
///
template <template <typename> typename InboundDeleter, template <typename> typename OutboundDeleter>
class BaseSemaphore {
 protected:
  /// The registered memory for the remote peer's inbound semaphore ID.
  std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;

  /// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
  ///
  /// The location of localInboundSemaphore_ can be either on the host or on the device.
  std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphore_;

  /// The expected inbound semaphore ID to be incremented by the local peer and compared to the
  /// localInboundSemaphore_.
  ///
  /// The location of expectedInboundSemaphore_ can be either on the host or on the device.
  std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphore_;

  /// The outbound semaphore ID that is incremented by the local peer and copied to the remote peer's
  /// localInboundSemaphore_.
  ///
  /// The location of outboundSemaphore_ can be either on the host or on the device.
  std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphore_;

 public:
  /// Constructs a BaseSemaphore.
  ///
  /// Takes sole ownership of the three counter allocations; the deleter template parameters determine
  /// whether each counter lives in host memory (std::default_delete) or device memory (a CUDA deleter).
  /// Note: remoteInboundSemaphoreIdsRegMem_ is not initialized here — derived classes are expected to
  /// populate it once the remote peer's memory registration is exchanged.
  ///
  /// @param localInboundSemaphoreId The inbound semaphore ID
  /// @param expectedInboundSemaphoreId The expected inbound semaphore ID
  /// @param outboundSemaphoreId The outbound semaphore ID
  BaseSemaphore(std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphoreId,
                std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphoreId,
                std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphoreId)
      : localInboundSemaphore_(std::move(localInboundSemaphoreId)),
        expectedInboundSemaphore_(std::move(expectedInboundSemaphoreId)),
        outboundSemaphore_(std::move(outboundSemaphoreId)) {}
};
/// A semaphore for sending signals from the host to the device.
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::default_delete> {
class Host2DeviceSemaphore {
private:
std::shared_ptr<Connection> connection_;
Semaphore semaphore_;
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
std::unique_ptr<uint64_t> outboundToken_;
public:
/// Constructor.
/// @param semaphore
Host2DeviceSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
@@ -76,7 +31,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection();
std::shared_ptr<Connection> connection() const;
/// Signal the device.
void signal();
@@ -85,13 +40,22 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle();
DeviceHandle deviceHandle() const;
};
/// A semaphore for sending signals from the local host to a remote host.
class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::default_delete> {
class Host2HostSemaphore {
private:
Semaphore semaphore_;
std::unique_ptr<uint64_t> expectedInboundToken_;
std::unique_ptr<uint64_t> outboundToken_;
public:
/// Constructor
/// Constructor.
/// @param semaphore
Host2HostSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
/// Host2HostSemaphore.
@@ -99,7 +63,7 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection();
std::shared_ptr<Connection> connection() const;
/// Signal the remote host.
void signal();
@@ -111,29 +75,34 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
/// Wait for the remote host to signal.
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
void wait(int64_t maxSpinCount = 10000000);
private:
std::shared_ptr<Connection> connection_;
};
/// A semaphore for sending signals from the local device to a peer device via a GPU thread.
class MemoryDevice2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, detail::GpuDeleter> {
class MemoryDevice2DeviceSemaphore {
private:
Semaphore semaphore_;
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
detail::UniqueGpuPtr<uint64_t> outboundToken_;
public:
/// Constructor.
/// @param semaphore
MemoryDevice2DeviceSemaphore(const Semaphore& semaphore);
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
/// Constructor.
MemoryDevice2DeviceSemaphore() = delete;
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection() const;
/// Device-side handle for MemoryDevice2DeviceSemaphore.
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;
/// Returns the device-side handle.
DeviceHandle deviceHandle() const;
bool isRemoteInboundSemaphoreIdSet_;
};
/// @deprecated Use MemoryDevice2DeviceSemaphore instead.

View File

@@ -33,28 +33,28 @@ struct Host2DeviceSemaphoreDeviceHandle {
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where a host thread (on behalf of the remote device) will write its semaphore value
/// and the local device will read it.
uint64_t* inboundSemaphoreId;
uint64_t* inboundToken;
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
uint64_t* expectedInboundSemaphoreId;
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
uint64_t* expectedInboundToken;
};
/// Device-side handle for MemoryDevice2DeviceSemaphore.
@@ -85,61 +85,61 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
auto outbound = incOutbound();
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
// Using memoryOrderSeqCst is faster for A100.
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderSeqCst);
atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst);
#else
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelease);
atomicStore(remoteInboundToken, outbound, memoryOrderRelease);
#endif
}
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
auto outbound = incOutbound();
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelaxed);
atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed);
}
/// Thread-safe read of expected inbound value.
/// @return The expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of expected inbound value.
/// @return The incremented expected inbound value.
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
}
/// Thread-safe read of inbound value.
/// @return The inbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
}
/// Thread-safe read of outbound value.
/// @return The outbound value.
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
return atomicLoad<uint64_t, scopeDevice>(outboundSemaphoreId, memoryOrderRelaxed);
return atomicLoad<uint64_t, scopeDevice>(outboundToken, memoryOrderRelaxed);
}
/// Thread-safe increment of outbound value.
/// @return The incremented outbound value.
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
return atomicFetchAdd<uint64_t, scopeDevice>(outboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
return atomicFetchAdd<uint64_t, scopeDevice>(outboundToken, 1, memoryOrderRelaxed) + 1;
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)
/// A local memory space where the remote device will write its semaphore value and the local device will read it.
uint64_t* inboundSemaphoreId;
uint64_t* inboundToken;
/// A local memory space where the local device stores the semaphore value to be written to the remote device.
uint64_t* outboundSemaphoreId;
uint64_t* outboundToken;
/// A remote memory space where the local device writes its outboundSemaphoreId on. This is inboundSemaphoreId of the
/// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the
/// remote device.
uint64_t* remoteInboundSemaphoreId;
uint64_t* remoteInboundToken;
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
uint64_t* expectedInboundSemaphoreId;
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
uint64_t* expectedInboundToken;
};
/// @deprecated Use MemoryDevice2DeviceSemaphoreDeviceHandle instead.

View File

@@ -14,6 +14,8 @@ from ._mscclpp import (
CudaError,
CuError,
IbError,
Device,
DeviceType,
Communicator,
Connection,
connect_nvls_collective,
@@ -43,6 +45,8 @@ from ._mscclpp import (
__all__ = [
"Device",
"DeviceType",
"Communicator",
"Connection",
"connect_nvls_collective",

View File

@@ -101,7 +101,7 @@ class CommGroup:
if endpoint.transport == Transport.Nvls:
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
else:
connections[rank] = self.communicator.connect(rank, 0, endpoint)
connections[rank] = self.communicator.connect(endpoint, rank)
connections = {rank: connections[rank].get() for rank in connections}
return connections
@@ -124,8 +124,8 @@ class CommGroup:
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
for rank in connections:
self.communicator.send_memory(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory(rank, 0)
self.communicator.send_memory(local_reg_memory, rank)
future_memories[rank] = self.communicator.recv_memory(rank)
for rank in connections:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories

View File

@@ -112,6 +112,19 @@ void register_core(nb::module_& m) {
.def(nb::self == nb::self)
.def(nb::self != nb::self);
nb::enum_<DeviceType>(m, "DeviceType")
.value("Unknown", DeviceType::Unknown)
.value("CPU", DeviceType::CPU)
.value("GPU", DeviceType::GPU);
nb::class_<Device>(m, "Device")
.def(nb::init<>())
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
.def_rw("type", &Device::type)
.def_rw("id", &Device::id)
.def("__str__", [](const Device& self) { return std::to_string(self); });
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", &RegisteredMemory::data)
@@ -120,6 +133,13 @@ void register_core(nb::module_& m) {
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("device", &Endpoint::device)
.def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
nb::class_<Connection>(m, "Connection")
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
nb::arg("size"))
@@ -131,21 +151,26 @@ void register_core(nb::module_& m) {
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
.def("remote_transport", &Connection::remoteTransport)
.def("context", &Connection::context)
.def("local_device", &Connection::localDevice)
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
.def_rw("transport", &EndpointConfig::transport)
.def_rw("device", &EndpointConfig::device)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
nb::class_<Context>(m, "Context")
.def_static("create", &Context::create)
@@ -158,6 +183,19 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
.def("memory", &SemaphoreStub::memory)
.def("serialize", &SemaphoreStub::serialize)
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
nb::class_<Semaphore>(m, "Semaphore")
.def(nb::init<>())
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("localStub"), nb::arg("remoteStub"))
.def("connection", &Semaphore::connection)
.def("local_memory", &Semaphore::localMemory)
.def("remote_memory", &Semaphore::remoteMemory);
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
@@ -172,12 +210,28 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def(
"connect",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
return self->connect(std::move(localConfig), remoteRank, tag);
},
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def(
"connect_on_setup",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
return self->connect(std::move(localConfig), remoteRank, tag);
},
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("localFlag"), nb::arg("remoteRank"),
nb::arg("tag") = 0)
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", [](Communicator*) {});

View File

@@ -11,12 +11,10 @@ namespace nb = nanobind;
using namespace mscclpp;
void register_memory_channel(nb::module_& m) {
nb::class_<BaseMemoryChannel> baseMemoryChannel(m, "BaseMemoryChannel");
baseMemoryChannel
.def("__init__",
[](BaseMemoryChannel* baseMemoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore) {
new (baseMemoryChannel) BaseMemoryChannel(semaphore);
})
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
.def(nb::init<>())
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def("device_handle", &BaseMemoryChannel::deviceHandle);
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
@@ -26,8 +24,8 @@ void register_memory_channel(nb::module_& m) {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<MemoryChannel> memoryChannel(m, "MemoryChannel");
memoryChannel
nb::class_<MemoryChannel>(m, "MemoryChannel")
.def(nb::init<>())
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst,

View File

@@ -20,13 +20,19 @@ void register_port_channel(nb::module_& m) {
.def("start_proxy", &ProxyService::startProxy)
.def("stop_proxy", &ProxyService::stopProxy)
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
.def("add_semaphore", static_cast<SemaphoreId (ProxyService::*)(const Semaphore&)>(&ProxyService::addSemaphore),
nb::arg("semaphore"))
.def("add_semaphore",
static_cast<SemaphoreId (ProxyService::*)(std::shared_ptr<Host2DeviceSemaphore>)>(
&ProxyService::addSemaphore),
nb::arg("semaphore"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
nb::class_<BasePortChannel>(m, "BasePortChannel")
.def(nb::init<>())
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &BasePortChannel::deviceHandle);
@@ -41,6 +47,7 @@ void register_port_channel(nb::module_& m) {
});
nb::class_<PortChannel>(m, "PortChannel")
.def(nb::init<>())
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
.def("device_handle", &PortChannel::deviceHandle);

View File

@@ -11,7 +11,7 @@ using namespace mscclpp;
void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
host2DeviceSemaphore
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2DeviceSemaphore::connection)
.def("signal", &Host2DeviceSemaphore::signal)
@@ -19,13 +19,14 @@ void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_rw("inbound_token", &Host2DeviceSemaphore::DeviceHandle::inboundToken)
.def_rw("expected_inbound_token", &Host2DeviceSemaphore::DeviceHandle::expectedInboundToken)
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2HostSemaphore::connection)
.def("signal", &Host2HostSemaphore::signal)
@@ -34,16 +35,17 @@ void register_semaphore(nb::module_& m) {
nb::arg("max_spin_count") = 10000000);
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
memoryDevice2DeviceSemaphore
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
.def(nb::init<>())
.def_rw("inboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
.def_rw("outboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
.def_rw("remoteInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
.def_rw("expectedInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

View File

@@ -28,6 +28,8 @@ from mscclpp import (
is_nvls_supported,
npkit,
env,
Device,
DeviceType,
)
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
@@ -280,7 +282,13 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores(mpi_group: MpiGroup):
group, connections = create_group_and_connection(mpi_group, "IB")
group = mscclpp_comm.CommGroup(mpi_group.comm)
tran = group.my_ib_device(group.my_rank % 8)
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
remote_nghrs = list(range(group.nranks))
remote_nghrs.remove(group.my_rank)
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
connections = {rank: conn.get() for rank, conn in connections.items()}
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
for rank in connections:
@@ -293,7 +301,13 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
group, connections = create_group_and_connection(mpi_group, "IB")
group = mscclpp_comm.CommGroup(mpi_group.comm)
tran = group.my_ib_device(group.my_rank % 8)
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
remote_nghrs = list(range(group.nranks))
remote_nghrs.remove(group.my_rank)
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
connections = {rank: conn.get() for rank, conn in connections.items()}
semaphores = group.make_semaphore(connections, Host2HostSemaphore)

View File

@@ -50,7 +50,7 @@ const char* SocketToString(union SocketAddress* addr, char* buf, const int numer
static int getTcpFinTimeout() {
std::ifstream ifs("/proc/sys/net/ipv4/tcp_fin_timeout");
if (!ifs.is_open()) {
throw mscclpp::SysError("open /proc/sys/net/ipv4/tcp_fin_timeout failed", errno);
throw SysError("open /proc/sys/net/ipv4/tcp_fin_timeout failed", errno);
}
int timeout;
ifs >> timeout;
@@ -80,12 +80,12 @@ static int findInterfaces(const char* prefixList, char* names, union SocketAddre
#ifdef MSCCLPP_ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN + 1];
#endif
struct mscclpp::netIf userIfs[MAX_IFS];
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if (searchNot) prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if (searchExact) prefixList++;
int nUserIfs = mscclpp::parseStringList(prefixList, userIfs, MAX_IFS);
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
struct ifaddrs *interfaces, *interface;
@@ -110,7 +110,7 @@ static int findInterfaces(const char* prefixList, char* names, union SocketAddre
}
// check against user specified interfaces
if (!(mscclpp::matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
@@ -224,16 +224,16 @@ int FindInterfaceMatchSubnet(char* ifNames, union SocketAddress* localAddrs, uni
void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair) {
if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
throw mscclpp::Error("Net : string is null", mscclpp::ErrorCode::InvalidUsage);
throw Error("Net : string is null", ErrorCode::InvalidUsage);
}
bool ipv6 = ip_port_pair[0] == '[';
/* Construct the sockaddress structure */
if (!ipv6) {
struct mscclpp::netIf ni;
struct netIf ni;
// parse <ip_or_hostname>:<port> string, expect one pair
if (mscclpp::parseStringList(ip_port_pair, &ni, 1) != 1) {
throw mscclpp::Error("Net : No valid <IPv4_or_hostname>:<port> pair found", mscclpp::ErrorCode::InvalidUsage);
if (parseStringList(ip_port_pair, &ni, 1) != 1) {
throw Error("Net : No valid <IPv4_or_hostname>:<port> pair found", ErrorCode::InvalidUsage);
}
struct addrinfo hints, *p;
@@ -245,7 +245,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
if ((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
std::stringstream ss;
ss << "Net : error encountered when getting address info : " << gai_strerror(rv);
throw mscclpp::Error(ss.str(), mscclpp::ErrorCode::InvalidUsage);
throw Error(ss.str(), ErrorCode::InvalidUsage);
}
// use the first
@@ -263,7 +263,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = 0; // should be global scope, set to 0
} else {
throw mscclpp::Error("Net : unsupported IP family", mscclpp::ErrorCode::InvalidUsage);
throw Error("Net : unsupported IP family", ErrorCode::InvalidUsage);
}
freeaddrinfo(p); // all done with this structure
@@ -276,7 +276,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
}
if (i == len) {
WARN("Net : No valid [IPv6]:port pair found");
throw mscclpp::Error("Net : No valid [IPv6]:port pair found", mscclpp::ErrorCode::InvalidUsage);
throw Error("Net : No valid [IPv6]:port pair found", ErrorCode::InvalidUsage);
}
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
@@ -448,7 +448,7 @@ void Socket::bindAndListen() {
}
void Socket::connect(int64_t timeout) {
mscclpp::Timer timer;
Timer timer;
#ifdef MSCCLPP_ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN + 1];
#endif
@@ -483,7 +483,7 @@ void Socket::connect(int64_t timeout) {
}
void Socket::accept(const Socket* listenSocket, int64_t timeout) {
mscclpp::Timer timer;
Timer timer;
if (listenSocket == NULL) {
throw Error("listenSocket is NULL", ErrorCode::InvalidUsage);

View File

@@ -99,8 +99,8 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
EndpointConfig localConfig) {
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
int remoteRank, int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);
if (remoteRank == bootstrap()->getRank()) {
@@ -134,6 +134,33 @@ MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::co
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
EndpointConfig localConfig) {
return connect(localConfig, remoteRank, tag);
}
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::shared_ptr<Connection> connection,
int remoteRank, int tag) {
SemaphoreStub localStub(connection);
bootstrap()->send(localStub.serialize(), remoteRank, tag);
auto future =
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
localStub = localStub]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteStub = SemaphoreStub::deserialize(data);
return Semaphore(localStub, remoteStub);
});
auto shared_future = std::shared_future<Semaphore>(std::move(future));
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Semaphore>>(shared_future));
return shared_future;
}
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
return pimpl_->connectionInfos_.at(&connection).remoteRank;
}

View File

@@ -12,6 +12,7 @@
#include <sstream>
#include <thread>
#include "api.h"
#include "debug.h"
#include "endpoint.hpp"
@@ -32,32 +33,21 @@ std::shared_ptr<RegisteredMemory::Impl> Connection::getImpl(RegisteredMemory& me
std::shared_ptr<Endpoint::Impl> Connection::getImpl(Endpoint& memory) { return memory.pimpl_; }
std::string Connection::getTransportName() const {
return TransportNames[static_cast<int>(this->transport())] + " -> " +
TransportNames[static_cast<int>(this->remoteTransport())];
}
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return localEndpoint_.device(); }
int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
// CudaIpcConnection
CudaIpcConnection::CudaIpcConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint,
std::shared_ptr<CudaIpcStream> stream)
: Connection(context, localEndpoint.maxWriteQueueSize()), stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
: Connection(context, localEndpoint), stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc || remoteEndpoint.transport() != Transport::CudaIpc) {
throw Error("CudaIpc transport is required for CudaIpcConnection", ErrorCode::InvalidUsage);
}
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made to a Cuda IPC endpoint", ErrorCode::InvalidUsage);
if (localEndpoint.device().type != DeviceType::GPU && remoteEndpoint.device().type != DeviceType::GPU) {
throw Error("CudaIpcConnection requires at least one GPU endpoint", ErrorCode::InvalidUsage);
}
// sanity check: make sure the IPC connection is being made within a node
if (getImpl(remoteEndpoint)->hostHash_ != getImpl(localEndpoint)->hostHash_) {
std::stringstream ss;
ss << "Cuda IPC connection can only be made within a node: " << std::hex << getImpl(remoteEndpoint)->hostHash_
<< " != " << std::hex << getImpl(localEndpoint)->hostHash_;
throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage);
}
INFO(MSCCLPP_P2P, "Cuda IPC connection created");
}
Transport CudaIpcConnection::transport() const { return Transport::CudaIpc; }
@@ -126,11 +116,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
// IBConnection
IBConnection::IBConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint)
: Connection(context, localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
: EndpointConfig::DefaultMaxCqSize),
: Connection(context, localEndpoint),
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
if (maxWriteQueueSize_ == -1) {
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
}
qp_ = getImpl(localEndpoint)->ibQp_;
qp_->rtr(getImpl(remoteEndpoint)->ibQpInfo_);
qp_->rts();
@@ -240,13 +232,13 @@ void IBConnection::flush(int64_t timeoutUsec) {
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, Endpoint localEndpoint,
Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize)
: Connection(context, localEndpoint.maxWriteQueueSize()),
: Connection(context, localEndpoint),
abortFlag_(0),
sendBufferSize_(sendBufferSize),
recvBufferSize_(recvBufferSize) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
throw Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
}
// Instanciating Buffers

View File

@@ -13,10 +13,14 @@
namespace mscclpp {
CudaIpcStream::CudaIpcStream() : stream_(std::make_shared<CudaStreamWithFlags>()), dirty_(false) {}
CudaIpcStream::CudaIpcStream(int deviceId)
: stream_(std::make_shared<CudaStreamWithFlags>()), deviceId_(deviceId), dirty_(false) {}
void CudaIpcStream::setStreamIfNeeded() {
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) {
MSCCLPP_CUDATHROW(cudaSetDevice(deviceId_));
stream_->set(cudaStreamNonBlocking);
}
}
void CudaIpcStream::memcpyD2D(void *dst, const void *src, size_t nbytes) {
@@ -68,29 +72,40 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
std::shared_ptr<Connection> conn;
if (localEndpoint.transport() == Transport::CudaIpc) {
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
}
int deviceId;
if (localEndpoint.device().type == DeviceType::GPU) {
deviceId = localEndpoint.device().id;
} else if (remoteEndpoint.device().type == DeviceType::GPU) {
deviceId = remoteEndpoint.device().id;
} else {
throw Error("CudaIpc transport requires at least one GPU device", ErrorCode::InvalidUsage);
}
if (deviceId < 0) {
throw Error("No GPU device ID provided", ErrorCode::InvalidUsage);
}
#if defined(MSCCLPP_DEVICE_HIP)
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>(deviceId));
#else // !defined(MSCCLPP_DEVICE_HIP)
if (pimpl_->ipcStreams_.empty()) {
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>(deviceId));
}
#endif // !defined(MSCCLPP_DEVICE_HIP)
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint,
pimpl_->ipcStreams_.back());
} else if (AllIBTransports.has(localEndpoint.transport())) {
if (!AllIBTransports.has(remoteEndpoint.transport())) {
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else {
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
throw Error("Unsupported transport", ErrorCode::InternalError);
}
return conn;
}

View File

@@ -93,6 +93,18 @@ const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Tran
namespace std {
// Maps a Transport enum value to its short display name.
std::string to_string(const mscclpp::Transport& transport) {
  static const std::string kTransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
                                                "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
  const auto index = static_cast<size_t>(transport);
  return kTransportNames[index];
}
// Formats a Device as "Device(type=<type>, id=<id>)" for logging/debugging.
std::string to_string(const mscclpp::Device& device) {
  std::ostringstream out;
  out << "Device(type=" << to_string(device.type) << ", id=" << device.id << ")";
  return out.str();
}
template <>
struct hash<mscclpp::TransportFlags> {
size_t operator()(const mscclpp::TransportFlags& flags) const {

View File

@@ -7,13 +7,20 @@
#include "api.h"
#include "context.hpp"
#include "serialization.hpp"
#include "socket.h"
#include "utils_internal.hpp"
namespace mscclpp {
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) {
: transport_(config.transport),
device_(config.device),
hostHash_(getHostHash()),
maxWriteQueueSize_(config.maxWriteQueueSize) {
if (device_.type == DeviceType::GPU && device_.id < 0) {
MSCCLPP_CUDATHROW(cudaGetDevice(&(device_.id)));
}
if (AllIBTransports.has(transport_)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport_)
@@ -23,7 +30,7 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
// Configuring Ethernet Interfaces
abortFlag_ = 0;
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);
if (ret <= 0) throw Error("Failed to find network interfaces", ErrorCode::InternalError);
// Starting Server Socket
socket_ = std::make_unique<Socket>(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_);
@@ -32,20 +39,38 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
}
}
MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; }
// Reconstructs an Endpoint::Impl from the byte stream produced by
// Endpoint::serialize(). Field order must match serialize() exactly:
// transport, device, host hash, then the transport-specific payload
// (IB queue-pair info or Ethernet socket address).
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
  auto it = serialization.begin();
  it = detail::deserialize(it, transport_);
  it = detail::deserialize(it, device_);
  it = detail::deserialize(it, hostHash_);
  if (AllIBTransports.has(transport_)) {
    // Endpoints built from serialized bytes are treated as remote (not the
    // locally created side of the IB queue pair).
    ibLocal_ = false;
    it = detail::deserialize(it, ibQpInfo_);
  }
  if (transport_ == Transport::Ethernet) {
    it = detail::deserialize(it, socketAddress_);
  }
}
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() { return pimpl_->maxWriteQueueSize_; }
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->transport_; }
MSCCLPP_API_CPP const Device& Endpoint::device() const { return pimpl_->device_; }
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() const { return pimpl_->maxWriteQueueSize_; }
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() const {
std::vector<char> data;
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash_), sizeof(pimpl_->hostHash_), std::back_inserter(data));
detail::serialize(data, pimpl_->transport_);
detail::serialize(data, pimpl_->device_);
detail::serialize(data, pimpl_->hostHash_);
if (AllIBTransports.has(pimpl_->transport_)) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
detail::serialize(data, pimpl_->ibQpInfo_);
}
if ((pimpl_->transport_) == Transport::Ethernet) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
std::back_inserter(data));
detail::serialize(data, pimpl_->socketAddress_);
}
return data;
}
@@ -54,23 +79,4 @@ MSCCLPP_API_CPP Endpoint Endpoint::deserialize(const std::vector<char>& data) {
return Endpoint(std::make_shared<Impl>(data));
}
// Deserializing constructor: reads fields back in the exact order and byte
// widths that Endpoint::serialize() wrote them, advancing the input iterator
// past each field after copying it.
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
  auto it = serialization.begin();
  std::copy_n(it, sizeof(transport_), reinterpret_cast<char*>(&transport_));
  it += sizeof(transport_);
  std::copy_n(it, sizeof(hostHash_), reinterpret_cast<char*>(&hostHash_));
  it += sizeof(hostHash_);
  if (AllIBTransports.has(transport_)) {
    // Endpoints built from serialized bytes are treated as remote (not the
    // locally created side of the IB queue pair).
    ibLocal_ = false;
    std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
    it += sizeof(ibQpInfo_);
  }
  if (transport_ == Transport::Ethernet) {
    std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
    it += sizeof(socketAddress_);
  }
}
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
} // namespace mscclpp

View File

@@ -185,7 +185,7 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
else if (this->outputChunks.at(rank) != 0)
sizePerRank = outputSize / this->outputChunks.at(rank);
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
if (this->isUsingPacket) {
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
@@ -203,7 +203,7 @@ size_t ExecutionPlan::Impl::getMaxScratchBufferSize(int rank) const {
else if (this->outputChunks.at(rank) != 0)
sizePerChunk = maxMessageSize / this->outputChunks.at(rank);
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
return this->getScratchBufferSize(rank, sizePerChunk * this->inputChunks.at(rank),
sizePerChunk * this->outputChunks.at(rank));
@@ -414,12 +414,12 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
for (const auto& op : threadblock["ops"]) {
Operation operation = {};
std::vector<uint32_t> chunkIndexes;
operation.type = static_cast<mscclpp::OperationType>(getOpType(op["name"]));
operation.type = static_cast<OperationType>(getOpType(op["name"]));
if (op.contains("ctype")) {
operation.channelType = convertToChannelType(op["ctype"]);
}
if (op.contains("i_cids")) {
if (operation.channelType == mscclpp::ChannelType::NVLS) {
if (operation.channelType == ChannelType::NVLS) {
BufferType srcBufferType = convertToBufferType(op["srcbuff"]);
operation.nvlsInputIndex =
channelIndexes[{srcBufferType, srcBufferType, ChannelType::NVLS}][op["i_cids"][0]["id"]];
@@ -453,7 +453,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
if (op.contains("o_cids")) {
operation.nOutputs = op["o_cids"].size();
for (int i = 0; i < operation.nOutputs; i++) {
if (operation.channelType == mscclpp::ChannelType::NVLS) {
if (operation.channelType == ChannelType::NVLS) {
BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
operation.nvlsOutputIndex =
channelIndexes[{dstBufferType, dstBufferType, ChannelType::NVLS}][op["o_cids"][0]["id"]];
@@ -516,14 +516,13 @@ std::pair<size_t, uint32_t> ExecutionPlan::Impl::getSizeAndChunksForRank(int ran
size_t outputSize) const {
std::pair<size_t, uint32_t> sizePerRank;
if (this->inputChunks.at(rank) == 0 && this->outputChunks.at(rank) == 0) {
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
} else if (this->inputChunks.at(rank) != 0 && this->outputChunks.at(rank) != 0) {
if (inputSize / this->inputChunks.at(rank) != outputSize / this->outputChunks.at(rank))
throw mscclpp::Error("Size per chunks inconsistent: inputSize " + std::to_string(inputSize) + " inputChunks " +
std::to_string(this->inputChunks.at(rank)) + " outputSize " +
std::to_string(outputSize) + " outputChunks " +
std::to_string(this->outputChunks.at(rank)),
mscclpp::ErrorCode::ExecutorError);
throw Error("Size per chunks inconsistent: inputSize " + std::to_string(inputSize) + " inputChunks " +
std::to_string(this->inputChunks.at(rank)) + " outputSize " + std::to_string(outputSize) +
" outputChunks " + std::to_string(this->outputChunks.at(rank)),
ErrorCode::ExecutorError);
else
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
} else if (this->inputChunks.at(rank) != 0) {

View File

@@ -216,7 +216,7 @@ struct Executor::Impl {
for (int peer : connectedPeers) {
Transport transport =
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
connectionFutures.push_back(this->comm->connect(peer, 0, transport));
connectionFutures.push_back(this->comm->connect(transport, peer));
}
for (size_t i = 0; i < connectionFutures.size(); i++) {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
@@ -263,12 +263,12 @@ struct Executor::Impl {
std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
for (int peer : connectedPeers) {
comm->sendMemory(memory, peer, 0);
comm->sendMemory(memory, peer);
}
channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
connectedPeers = getConnectedPeers(channelInfos);
for (int peer : connectedPeers) {
remoteRegMemoryFutures.push_back(comm->recvMemory(peer, 0));
remoteRegMemoryFutures.push_back(comm->recvMemory(peer));
}
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());

View File

@@ -242,8 +242,8 @@ bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
return false;
}
MSCCLPP_CUTHROW(cuMemRelease(handle));
if (!mscclpp::isNvlsSupported()) {
throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
if (!isNvlsSupported()) {
throw Error("cuMemMap is used in env without NVLS support", ErrorCode::InvalidUsage);
}
return true;
#endif

View File

@@ -52,7 +52,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
CUdeviceptr dptr = reinterpret_cast<CUdeviceptr>(buff);
bool cuMemAlloc = mscclpp::isCuMemMapAllocated((void*)dptr);
bool cuMemAlloc = isCuMemMapAllocated((void*)dptr);
int dmaBufSupported = 0;
#if !defined(__HIP_PLATFORM_AMD__)
CUdevice dev;
@@ -71,11 +71,10 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
if (this->mr == nullptr) {
std::stringstream err;
err << "ibv_reg_dmabuf_mr failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
#else
throw mscclpp::Error("Registeration of dma-buf based memory region failed on HIP platform",
ErrorCode::InvalidUsage);
throw Error("Registeration of dma-buf based memory region failed on HIP platform", ErrorCode::InvalidUsage);
#endif // !defined(__HIP_PLATFORM_AMD__)
} else {
this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast<void*>(addr), pages * pageSize,
@@ -84,7 +83,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
if (this->mr == nullptr) {
std::stringstream err;
err << "ibv_reg_mr failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
}
@@ -111,7 +110,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
if (this->cq == nullptr) {
std::stringstream err;
err << "ibv_create_cq failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
struct ibv_qp_init_attr qpInitAttr;
@@ -130,14 +129,14 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
if (_qp == nullptr) {
std::stringstream err;
err << "ibv_create_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
struct ibv_port_attr portAttr;
if (IBVerbs::ibv_query_port_w(ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
this->info.lid = portAttr.lid;
this->info.port = port;
@@ -151,7 +150,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
if (IBVerbs::ibv_query_gid(ctx, port, 0, &gid) != 0) {
std::stringstream err;
err << "ibv_query_gid failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
this->info.spn = gid.global.subnet_prefix;
this->info.iid = gid.global.interface_id;
@@ -166,7 +165,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
if (IBVerbs::ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
this->qp = _qp;
this->wrn = 0;
@@ -210,7 +209,7 @@ void IbQp::rtr(const IbQpInfo& info) {
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
}
@@ -229,7 +228,7 @@ void IbQp::rts() {
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
}
@@ -237,7 +236,7 @@ IbQp::WrInfo IbQp::getNewWrInfo() {
if (this->wrn >= this->maxWrPerSend) {
std::stringstream err;
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
throw Error(err.str(), ErrorCode::InvalidUsage);
}
int wrn = this->wrn;
@@ -306,7 +305,7 @@ void IbQp::postSend() {
if (ret != 0) {
std::stringstream err;
err << "ibv_post_send failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
this->wrn = 0;
this->numSignaledPostedItems += this->numSignaledStagedItems;
@@ -332,7 +331,7 @@ int IbQp::getNumCqItems() const { return this->numSignaledPostedItems; }
IbCtx::IbCtx(const std::string& devName) : devName(devName) {
#if !defined(__HIP_PLATFORM_AMD__)
if (!checkNvPeerMemLoaded()) {
throw mscclpp::Error("nvidia_peermem kernel module is not loaded", ErrorCode::InternalError);
throw Error("nvidia_peermem kernel module is not loaded", ErrorCode::InternalError);
}
#endif // !defined(__HIP_PLATFORM_AMD__)
int num;
@@ -347,13 +346,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
if (this->ctx == nullptr) {
std::stringstream err;
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
this->pd = IBVerbs::ibv_alloc_pd(this->ctx);
if (this->pd == nullptr) {
std::stringstream err;
err << "ibv_alloc_pd failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
}
@@ -373,7 +372,7 @@ bool IbCtx::isPortUsable(int port) const {
if (IBVerbs::ibv_query_port_w(this->ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
return portAttr.state == IBV_PORT_ACTIVE &&
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
@@ -384,7 +383,7 @@ int IbCtx::getAnyActivePort() const {
if (IBVerbs::ibv_query_device(this->ctx, &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
throw IbError(err.str(), errno);
}
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
if (this->isPortUsable(port)) {
@@ -399,10 +398,10 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
throw mscclpp::Error("No active port found", ErrorCode::InvalidUsage);
throw Error("No active port found", ErrorCode::InvalidUsage);
}
} else if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
throw Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
}
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
return qps.back().get();

View File

@@ -16,12 +16,13 @@ namespace mscclpp {
class CudaIpcStream {
private:
std::shared_ptr<CudaStreamWithFlags> stream_;
int deviceId_;
bool dirty_;
void setStreamIfNeeded();
public:
CudaIpcStream();
CudaIpcStream(int deviceId);
void memcpyD2D(void *dst, const void *src, size_t nbytes);
@@ -30,6 +31,8 @@ class CudaIpcStream {
void sync();
operator cudaStream_t() const { return *stream_; }
int deviceId() const { return deviceId_; }
};
struct Context::Impl {

View File

@@ -19,6 +19,7 @@ struct Endpoint::Impl {
Impl(const std::vector<char>& serialization);
Transport transport_;
Device device_;
uint64_t hostHash_;
int maxWriteQueueSize_;

View File

@@ -60,6 +60,7 @@ struct RegisteredMemory::Impl {
int fileDesc = -1;
Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl);
Impl(const std::vector<char>::const_iterator& begin, const std::vector<char>::const_iterator& end);
/// Constructs a RegisteredMemory::Impl from a vector of data. The constructor should only be used for the remote
/// memory.
Impl(const std::vector<char>& data);

View File

@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_SERIALIZATION_HPP_
#define MSCCLPP_SERIALIZATION_HPP_
#include <algorithm>
#include <vector>
namespace mscclpp::detail {
template <typename T>
void serialize(std::vector<char>& buffer, const T& value) {
const char* data = reinterpret_cast<const char*>(&value);
std::copy_n(data, sizeof(T), std::back_inserter(buffer));
}
template <typename T>
std::vector<char>::const_iterator deserialize(const std::vector<char>::const_iterator& pos, T& value) {
std::copy_n(pos, sizeof(T), reinterpret_cast<char*>(&value));
return pos + sizeof(T);
}
} // namespace mscclpp::detail
#endif // MSCCLPP_SERIALIZATION_HPP_

View File

@@ -11,6 +11,9 @@ namespace mscclpp {
MSCCLPP_API_CPP BaseMemoryChannel::BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore)
: semaphore_(semaphore) {}
// Convenience overload: wraps the generic Semaphore in a
// MemoryDevice2DeviceSemaphore and delegates to the shared_ptr-based ctor.
MSCCLPP_API_CPP BaseMemoryChannel::BaseMemoryChannel(const Semaphore& semaphore)
    : BaseMemoryChannel(std::make_shared<MemoryDevice2DeviceSemaphore>(semaphore)) {}
MSCCLPP_API_CPP MemoryChannel::MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, void* src, void* packetBuffer)
: BaseMemoryChannel(semaphore), dst_(dst), src_(src), packetBuffer_(packetBuffer) {
@@ -19,6 +22,10 @@ MSCCLPP_API_CPP MemoryChannel::MemoryChannel(std::shared_ptr<MemoryDevice2Device
}
}
// Convenience overload: builds the channel from a generic Semaphore by
// wrapping it in a MemoryDevice2DeviceSemaphore and delegating.
MSCCLPP_API_CPP MemoryChannel::MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src,
                                             void* packetBuffer)
    : MemoryChannel(std::make_shared<MemoryDevice2DeviceSemaphore>(semaphore), dst, src, packetBuffer) {}
MSCCLPP_API_CPP BaseMemoryChannel::DeviceHandle BaseMemoryChannel::deviceHandle() const {
return BaseMemoryChannel::DeviceHandle(semaphore_->deviceHandle());
}

View File

@@ -14,10 +14,18 @@ MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId,
std::shared_ptr<Proxy> proxy)
: semaphoreId_(semaphoreId), semaphore_(semaphore), proxy_(proxy) {}
// Convenience overload: wraps the generic Semaphore in a Host2DeviceSemaphore
// and delegates to the shared_ptr-based constructor.
MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore,
                                                 std::shared_ptr<Proxy> proxy)
    : BasePortChannel(semaphoreId, std::make_shared<Host2DeviceSemaphore>(semaphore), proxy) {}
// Full port channel: base channel state plus the registered-memory ids of the
// destination and source buffers used by channel operations.
MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
                                         std::shared_ptr<Proxy> proxy, MemoryId dst, MemoryId src)
    : BasePortChannel(semaphoreId, semaphore, proxy), dst_(dst), src_(src) {}
// Convenience overload taking a generic Semaphore; the BasePortChannel
// delegate wraps it in a Host2DeviceSemaphore.
MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore,
                                         std::shared_ptr<Proxy> proxy, MemoryId dst, MemoryId src)
    : BasePortChannel(semaphoreId, semaphore, proxy), dst_(dst), src_(src) {}
MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
@@ -39,6 +47,11 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& com
return semaphores_.size() - 1;
}
// Adapts the generic Semaphore for host-to-device signaling and registers it;
// the returned id is the semaphore's slot in the service's table.
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(const Semaphore& semaphore) {
  auto hostSemaphore = std::make_shared<Host2DeviceSemaphore>(semaphore);
  semaphores_.push_back(std::move(hostSemaphore));
  return semaphores_.size() - 1;
}
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
semaphores_.push_back(semaphore);
return semaphores_.size() - 1;

View File

@@ -12,6 +12,7 @@
#include "api.h"
#include "context.hpp"
#include "debug.h"
#include "serialization.hpp"
#include "utils_internal.hpp"
#define MSCCLPP_CULOG_WARN(cmd) \
@@ -119,43 +120,37 @@ MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() const { return pim
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() const {
std::vector<char> result;
std::copy_n(reinterpret_cast<char*>(&pimpl_->originalDataPtr), sizeof(pimpl_->originalDataPtr),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->baseDataSize), sizeof(pimpl_->baseDataSize), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->isCuMemMapAlloc), sizeof(pimpl_->isCuMemMapAlloc),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result));
detail::serialize(result, pimpl_->originalDataPtr);
detail::serialize(result, pimpl_->size);
detail::serialize(result, pimpl_->baseDataSize);
detail::serialize(result, pimpl_->hostHash);
detail::serialize(result, pimpl_->pidHash);
detail::serialize(result, pimpl_->isCuMemMapAlloc);
detail::serialize(result, pimpl_->transports);
if (pimpl_->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError);
throw Error("Too many transport info entries", ErrorCode::InternalError);
}
int8_t transportCount = pimpl_->transportInfos.size();
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
detail::serialize(result, transportCount);
for (auto& entry : pimpl_->transportInfos) {
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
detail::serialize(result, entry.transport);
if (entry.transport == Transport::CudaIpc) {
if (pimpl_->isCuMemMapAlloc) {
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
std::copy_n(reinterpret_cast<char*>(&entry.shareableHandle), sizeof(entry.shareableHandle),
std::back_inserter(result));
detail::serialize(result, entry.shareableHandle);
} else {
std::copy_n(reinterpret_cast<char*>(&entry.rootPid), sizeof(entry.rootPid), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.fileDesc), sizeof(entry.fileDesc), std::back_inserter(result));
detail::serialize(result, entry.rootPid);
detail::serialize(result, entry.fileDesc);
}
std::copy_n(reinterpret_cast<char*>(&entry.offsetFromBase), sizeof(entry.offsetFromBase),
std::back_inserter(result));
detail::serialize(result, entry.offsetFromBase);
} else {
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase),
std::back_inserter(result));
detail::serialize(result, entry.cudaIpcBaseHandle);
detail::serialize(result, entry.cudaIpcOffsetFromBase);
}
} else if (AllIBTransports.has(entry.transport)) {
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
detail::serialize(result, entry.ibMrInfo);
} else {
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
throw Error("Unknown transport", ErrorCode::InternalError);
}
}
return result;
@@ -165,62 +160,44 @@ MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector
return RegisteredMemory(std::make_shared<Impl>(data));
}
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
auto it = serialization.begin();
std::copy_n(it, sizeof(this->originalDataPtr), reinterpret_cast<char*>(&this->originalDataPtr));
it += sizeof(this->originalDataPtr);
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
it += sizeof(this->size);
std::copy_n(it, sizeof(this->baseDataSize), reinterpret_cast<char*>(&this->baseDataSize));
it += sizeof(this->baseDataSize);
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
it += sizeof(this->hostHash);
std::copy_n(it, sizeof(this->pidHash), reinterpret_cast<char*>(&this->pidHash));
it += sizeof(this->pidHash);
std::copy_n(it, sizeof(this->isCuMemMapAlloc), reinterpret_cast<char*>(&this->isCuMemMapAlloc));
it += sizeof(this->isCuMemMapAlloc);
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
it += sizeof(this->transports);
RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
const std::vector<char>::const_iterator& end) {
auto it = begin;
it = detail::deserialize(it, this->originalDataPtr);
it = detail::deserialize(it, this->size);
it = detail::deserialize(it, this->baseDataSize);
it = detail::deserialize(it, this->hostHash);
it = detail::deserialize(it, this->pidHash);
it = detail::deserialize(it, this->isCuMemMapAlloc);
it = detail::deserialize(it, this->transports);
int8_t transportCount;
std::copy_n(it, sizeof(transportCount), reinterpret_cast<char*>(&transportCount));
it += sizeof(transportCount);
it = detail::deserialize(it, transportCount);
for (int i = 0; i < transportCount; ++i) {
TransportInfo transportInfo;
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
it += sizeof(transportInfo.transport);
it = detail::deserialize(it, transportInfo.transport);
if (transportInfo.transport == Transport::CudaIpc) {
if (this->isCuMemMapAlloc) {
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
std::copy_n(it, sizeof(transportInfo.shareableHandle),
reinterpret_cast<char*>(&transportInfo.shareableHandle));
it += sizeof(transportInfo.shareableHandle);
it = detail::deserialize(it, transportInfo.shareableHandle);
} else {
std::copy_n(it, sizeof(transportInfo.rootPid), reinterpret_cast<char*>(&transportInfo.rootPid));
it += sizeof(transportInfo.rootPid);
std::copy_n(it, sizeof(transportInfo.fileDesc), reinterpret_cast<char*>(&transportInfo.fileDesc));
it += sizeof(transportInfo.fileDesc);
it = detail::deserialize(it, transportInfo.rootPid);
it = detail::deserialize(it, transportInfo.fileDesc);
}
std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast<char*>(&transportInfo.offsetFromBase));
it += sizeof(transportInfo.offsetFromBase);
it = detail::deserialize(it, transportInfo.offsetFromBase);
} else {
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
it += sizeof(transportInfo.cudaIpcBaseHandle);
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
it = detail::deserialize(it, transportInfo.cudaIpcBaseHandle);
it = detail::deserialize(it, transportInfo.cudaIpcOffsetFromBase);
}
} else if (AllIBTransports.has(transportInfo.transport)) {
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
it += sizeof(transportInfo.ibMrInfo);
it = detail::deserialize(it, transportInfo.ibMrInfo);
transportInfo.ibLocal = false;
} else {
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
throw Error("Unknown transport", ErrorCode::InternalError);
}
this->transportInfos.push_back(transportInfo);
}
if (it != serialization.end()) {
throw mscclpp::Error("Serialization failed", ErrorCode::InternalError);
if (it != end) {
throw Error("Serialization failed", ErrorCode::InternalError);
}
// Next decide how to set this->data
@@ -239,11 +216,11 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
} else {
int rootPidFd = syscall(SYS_pidfd_open, entry.rootPid, 0);
if (rootPidFd < 0) {
throw mscclpp::SysError("pidfd_open() failed", errno);
throw SysError("pidfd_open() failed", errno);
}
int fd = syscall(SYS_pidfd_getfd, rootPidFd, entry.fileDesc, 0);
if (fd < 0) {
throw mscclpp::SysError("pidfd_getfd() failed", errno);
throw SysError("pidfd_getfd() failed", errno);
}
INFO(MSCCLPP_P2P, "Get file descriptor %d from pidfd %d on peer 0x%lx", fd, rootPidFd, hostHash);
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast<void*>(fd),
@@ -260,9 +237,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
detail::setReadWriteMemoryAccess(base, size);
this->data = static_cast<char*>(base) + entry.offsetFromBase;
#else
throw mscclpp::Error(
"CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.",
mscclpp::ErrorCode::InvalidUsage);
throw Error("CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.",
ErrorCode::InvalidUsage);
#endif
} else {
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
@@ -275,6 +251,9 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
}
}
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
: Impl(serialization.begin(), serialization.end()) {}
RegisteredMemory::Impl::~Impl() {
// Close the CUDA IPC handle if it was opened during deserialization
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {

View File

@@ -5,116 +5,193 @@
#include "api.h"
#include "atomic.hpp"
#include "context.hpp"
#include "debug.h"
#include "registered_memory.hpp"
#include "serialization.hpp"
namespace mscclpp {
static std::shared_future<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
void* localInboundSemaphoreId) {
auto localInboundSemaphoreIdsRegMem =
communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport());
int remoteRank = communicator.remoteRankOf(*connection);
int tag = communicator.tagOf(*connection);
communicator.sendMemory(localInboundSemaphoreIdsRegMem, remoteRank, tag);
return communicator.recvMemory(remoteRank, tag);
struct SemaphoreStub::Impl {
Impl(std::shared_ptr<Connection> connection);
Impl(const RegisteredMemory& idMemory, const Device& device);
Impl(const std::vector<char>& data);
std::shared_ptr<Connection> connection_;
std::shared_ptr<uint64_t> token_;
RegisteredMemory idMemory_;
Device device_;
};
static std::shared_ptr<uint64_t> gpuCallocToken() {
#if defined(MSCCLPP_DEVICE_HIP)
return detail::gpuCallocUncachedShared<uint64_t>();
#else // !defined(MSCCLPP_DEVICE_HIP)
return detail::gpuCallocShared<uint64_t>();
#endif // !defined(MSCCLPP_DEVICE_HIP)
}
static detail::UniqueGpuPtr<uint64_t> createGpuSemaphoreId() {
#if defined(__HIP_PLATFORM_AMD__)
return detail::gpuCallocUncachedUnique<uint64_t>();
#else // !defined(__HIP_PLATFORM_AMD__)
return detail::gpuCallocUnique<uint64_t>();
#endif // !defined(__HIP_PLATFORM_AMD__)
SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection) : connection_(connection) {
// Allocate a semaphore ID on the local device
const Device& localDevice = connection_->localDevice();
if (localDevice.type == DeviceType::CPU) {
token_ = std::make_shared<uint64_t>(0);
} else if (localDevice.type == DeviceType::GPU) {
if (localDevice.id < 0) {
throw Error("Local GPU ID is not provided", ErrorCode::InvalidUsage);
}
MSCCLPP_CUDATHROW(cudaSetDevice(localDevice.id));
token_ = gpuCallocToken();
} else {
throw Error("Unsupported local device type", ErrorCode::InvalidUsage);
}
idMemory_ =
std::move(connection->context()->registerMemory(token_.get(), sizeof(uint64_t), connection_->transport()));
}
SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device)
: idMemory_(idMemory), device_(device) {}
SemaphoreStub::SemaphoreStub(std::shared_ptr<Impl> pimpl) : pimpl_(std::move(pimpl)) {}
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(std::shared_ptr<Connection> connection)
: pimpl_(std::make_shared<Impl>(connection)) {}
MSCCLPP_API_CPP std::vector<char> SemaphoreStub::serialize() const {
auto data = pimpl_->idMemory_.serialize();
detail::serialize(data, pimpl_->device_);
return data;
}
MSCCLPP_API_CPP SemaphoreStub SemaphoreStub::deserialize(const std::vector<char>& data) {
Device device;
auto memEnd = data.end() - sizeof(device);
RegisteredMemory idMemory(std::make_shared<RegisteredMemory::Impl>(data.begin(), memEnd));
auto it = detail::deserialize(memEnd, device);
if (it != data.end()) {
throw Error("SemaphoreStub deserialize failed", ErrorCode::InvalidUsage);
}
return SemaphoreStub(std::make_shared<Impl>(std::move(idMemory), device));
}
MSCCLPP_API_CPP const RegisteredMemory& SemaphoreStub::memory() const { return pimpl_->idMemory_; }
struct Semaphore::Impl {
Impl(const SemaphoreStub& localStub, const RegisteredMemory& remoteStubMemory)
: localStub_(localStub), remoteStubMemory_(remoteStubMemory) {}
SemaphoreStub localStub_;
RegisteredMemory remoteStubMemory_;
};
Semaphore::Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub)
: pimpl_(std::make_unique<Impl>(localStub, remoteStub.memory())) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> Semaphore::connection() const {
return pimpl_->localStub_.pimpl_->connection_;
}
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::localMemory() const { return pimpl_->localStub_.memory(); }
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::remoteMemory() const { return pimpl_->remoteStubMemory_; }
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, std::shared_ptr<Connection> connection) {
auto semaphoreFuture =
communicator.buildSemaphore(connection, communicator.remoteRankOf(*connection), communicator.tagOf(*connection));
return semaphoreFuture.get();
}
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(const Semaphore& semaphore)
: semaphore_(semaphore),
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
outboundToken_(std::make_unique<uint64_t>()) {
if (connection()->localDevice().type != DeviceType::GPU) {
throw Error("Local endpoint device type of Host2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), std::make_unique<uint64_t>()),
connection_(connection) {
INFO(MSCCLPP_INIT, "Creating a Host2Device semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
communicator.remoteRankOf(*connection));
remoteInboundSemaphoreIdsRegMem_ =
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
}
: Host2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() { return connection_; }
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() const { return semaphore_.connection(); }
MSCCLPP_API_CPP void Host2DeviceSemaphore::signal() {
connection_->updateAndSync(remoteInboundSemaphoreIdsRegMem_.get(), 0, outboundSemaphore_.get(),
*outboundSemaphore_ + 1);
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
}
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() {
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() const {
Host2DeviceSemaphore::DeviceHandle device;
device.inboundSemaphoreId = localInboundSemaphore_.get();
device.expectedInboundSemaphoreId = expectedInboundSemaphore_.get();
device.inboundToken = reinterpret_cast<uint64_t*>(semaphore_.localMemory().data());
device.expectedInboundToken = expectedInboundToken_.get();
return device;
}
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(const Semaphore& semaphore)
: semaphore_(semaphore),
expectedInboundToken_(std::make_unique<uint64_t>()),
outboundToken_(std::make_unique<uint64_t>()) {
if (connection()->transport() == Transport::CudaIpc) {
throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
}
if (connection()->localDevice().type != DeviceType::CPU) {
throw Error("Local endpoint device type of Host2HostSemaphore should be CPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
: BaseSemaphore(std::make_unique<uint64_t>(), std::make_unique<uint64_t>(), std::make_unique<uint64_t>()),
connection_(connection) {
INFO(MSCCLPP_INIT, "Creating a Host2Host semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
communicator.remoteRankOf(*connection));
: Host2HostSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
if (connection->transport() == Transport::CudaIpc) {
throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
}
remoteInboundSemaphoreIdsRegMem_ =
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
}
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() { return connection_; }
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() const { return semaphore_.connection(); }
MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
connection_->updateAndSync(remoteInboundSemaphoreIdsRegMem_.get(), 0, outboundSemaphore_.get(),
*outboundSemaphore_ + 1);
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
}
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
bool signaled =
(atomicLoad((uint64_t*)localInboundSemaphore_.get(), memoryOrderAcquire) > (*expectedInboundSemaphore_));
if (signaled) (*expectedInboundSemaphore_) += 1;
bool signaled = (atomicLoad(reinterpret_cast<uint64_t*>(semaphore_.localMemory().data()), memoryOrderAcquire) >
(*expectedInboundToken_));
if (signaled) (*expectedInboundToken_) += 1;
return signaled;
}
MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) {
(*expectedInboundSemaphore_) += 1;
(*expectedInboundToken_) += 1;
int64_t spinCount = 0;
while (atomicLoad((uint64_t*)localInboundSemaphore_.get(), memoryOrderAcquire) < (*expectedInboundSemaphore_)) {
while (atomicLoad(reinterpret_cast<uint64_t*>(semaphore_.localMemory().data()), memoryOrderAcquire) <
(*expectedInboundToken_)) {
if (maxSpinCount >= 0 && spinCount++ == maxSpinCount) {
throw Error("Host2HostSemaphore::wait timed out", ErrorCode::Timeout);
}
}
}
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const Semaphore& semaphore)
: semaphore_(semaphore),
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
outboundToken_(detail::gpuCallocUnique<uint64_t>()) {
if (connection()->localDevice().type != DeviceType::GPU) {
throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), createGpuSemaphoreId()) {
INFO(MSCCLPP_INIT, "Creating a Device2Device semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
communicator.remoteRankOf(*connection));
if (connection->transport() == Transport::CudaIpc) {
remoteInboundSemaphoreIdsRegMem_ =
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
isRemoteInboundSemaphoreIdSet_ = true;
} else if (AllIBTransports.has(connection->transport())) {
// Should we throw an error here?
isRemoteInboundSemaphoreIdSet_ = false;
}
: MemoryDevice2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> MemoryDevice2DeviceSemaphore::connection() const {
return semaphore_.connection();
}
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSemaphore::deviceHandle() const {
MemoryDevice2DeviceSemaphore::DeviceHandle device;
device.remoteInboundSemaphoreId = isRemoteInboundSemaphoreIdSet_
? reinterpret_cast<uint64_t*>(remoteInboundSemaphoreIdsRegMem_.get().data())
: nullptr;
device.inboundSemaphoreId = localInboundSemaphore_.get();
device.expectedInboundSemaphoreId = expectedInboundSemaphore_.get();
device.outboundSemaphoreId = outboundSemaphore_.get();
device.remoteInboundToken = reinterpret_cast<uint64_t*>(semaphore_.remoteMemory().data());
device.inboundToken = reinterpret_cast<uint64_t*>(semaphore_.localMemory().data());
device.expectedInboundToken = expectedInboundToken_.get();
device.outboundToken = outboundToken_.get();
return device;
};

View File

@@ -225,16 +225,17 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
transport = ibTransport;
}
// Connect with all other ranks
connections[r] = comm.connect(r, 0, transport);
connections[r] = comm.connect(transport, r);
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemory(memory, r, 0);
remoteMemories.push_back(comm.recvMemory(r, 0));
comm.sendMemory(memory, r);
remoteMemories.push_back(comm.recvMemory(r));
}
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
auto sema = comm.buildSemaphore(connections[r].get(), r).get();
semaphoreIds.push_back(proxyService.addSemaphore(sema));
}
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;

View File

@@ -83,7 +83,6 @@ class MyProxyService {
int dataSize_;
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
mscclpp::RegisteredMemory localMemory_;
std::vector<std::shared_ptr<mscclpp::Host2HostSemaphore>> hostSemaphores_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores1_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores2_;
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
@@ -105,7 +104,6 @@ class MyProxyService {
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
hostSemaphores_.emplace_back(nullptr);
deviceSemaphores1_.emplace_back(nullptr);
deviceSemaphores2_.emplace_back(nullptr);
continue;
@@ -117,10 +115,10 @@ class MyProxyService {
transport = ibTransport;
}
// Connect with all other ranks
connectionsFuture[r] = comm.connect(r, 0, transport);
comm.sendMemory(localMemory_, r, 0);
connectionsFuture[r] = comm.connect(transport, r);
comm.sendMemory(localMemory_, r);
remoteMemoriesFuture[r] = comm.recvMemory(r, 0);
remoteMemoriesFuture[r] = comm.recvMemory(r);
}
for (int r = 0; r < world_size; ++r) {
@@ -128,11 +126,6 @@ class MyProxyService {
continue;
}
connections_[r] = connectionsFuture[r].get();
if (rankToNode(r) == thisNode) {
hostSemaphores_.emplace_back(nullptr);
} else {
hostSemaphores_.emplace_back(std::make_shared<mscclpp::Host2HostSemaphore>(comm, connections_[r]));
}
deviceSemaphores1_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
deviceSemaphores2_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
remoteMemories_[r] = remoteMemoriesFuture[r].get();

View File

@@ -47,11 +47,11 @@ void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet
for (int i = 0; i < numRanksToUse; i++) {
if (i != gEnv->rank) {
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
connectionFutures[i] = communicator->connect(mscclpp::Transport::CudaIpc, i);
} else if (useIb) {
connectionFutures[i] = communicator->connect(i, 0, ibTransport);
connectionFutures[i] = communicator->connect(ibTransport, i);
} else if (useEthernet) {
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::Ethernet);
connectionFutures[i] = communicator->connect(mscclpp::Transport::Ethernet, i);
}
}
}

View File

@@ -39,28 +39,26 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
continue;
}
if (rankToNode(r) == rankToNode(gEnv->rank)) {
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
connectionFutures[r] = communicator->connect(mscclpp::Transport::CudaIpc, r);
} else {
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
connectionFutures[r] = communicator->connect(ibTransport, r);
}
if (isInPlace) {
communicator->sendMemory(inputBufRegMem, r, 0);
communicator->sendMemory(inputBufRegMem, r);
} else {
communicator->sendMemory(outputBufRegMem, r, 0);
communicator->sendMemory(outputBufRegMem, r);
}
remoteMemFutures[r] = communicator->recvMemory(r, 0);
remoteMemFutures[r] = communicator->recvMemory(r);
}
for (int r = 0; r < worldSize; r++) {
if (r == rank) {
continue;
}
connections[r] = connectionFutures[r].get();
auto sema = communicator->buildSemaphore(connectionFutures[r].get(), r).get();
memorySemaphores[r] = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, connections[r]);
memoryChannels.emplace_back(memorySemaphores[r], remoteMemFutures[r].get(), inputBufRegMem.data(),
memoryChannels.emplace_back(sema, remoteMemFutures[r].get(), inputBufRegMem.data(),
(isInPlace ? nullptr : outputBufRegMem.data()));
}
}

View File

@@ -42,26 +42,28 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
continue;
}
if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) {
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
connectionFutures[r] = communicator->connect(mscclpp::Transport::CudaIpc, r);
} else if (useIb) {
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
connectionFutures[r] = communicator->connect(ibTransport, r);
} else if (useEthernet) {
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::Ethernet);
connectionFutures[r] = communicator->connect(mscclpp::Transport::Ethernet, r);
}
if (isInPlace) {
communicator->sendMemory(sendBufRegMem, r, 0);
communicator->sendMemory(sendBufRegMem, r);
} else {
communicator->sendMemory(recvBufRegMem, r, 0);
communicator->sendMemory(recvBufRegMem, r);
}
remoteMemFutures[r] = communicator->recvMemory(r, 0);
remoteMemFutures[r] = communicator->recvMemory(r);
}
for (int r = 0; r < worldSize; r++) {
if (r == rank) {
continue;
}
mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, connectionFutures[r].get());
auto sema = communicator->buildSemaphore(connectionFutures[r].get(), r).get();
mscclpp::SemaphoreId cid = proxyService->addSemaphore(sema);
portChannels.emplace_back(proxyService->portChannel(cid, proxyService->addMemory(remoteMemFutures[r].get()),
proxyService->addMemory(sendBufRegMem)));

View File

@@ -386,10 +386,10 @@ void BaseTestEngine::setupMeshConnectionsInternal(
transport = ibTransport;
}
// Connect with all other ranks
connectionFutures.push_back(comm_->connect(r, 0, transport));
connectionFutures.push_back(comm_->connect(transport, r));
}
comm_->sendMemory(localRegMemory, r, 0);
auto remoteMemory = comm_->recvMemory(r, 0);
comm_->sendMemory(localRegMemory, r);
auto remoteMemory = comm_->recvMemory(r);
remoteRegMemories.push_back(remoteMemory);
}
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),

View File

@@ -153,18 +153,20 @@ void SendRecvTestEngine::setupConnections() {
std::array<int, 2> ranks = {sendToRank, recvFromRank};
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
auto sendConnFuture =
comm_->connect(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
std::vector<mscclpp::Semaphore> semaphores;
if (recvFromRank != sendToRank) {
auto recvConnFuture =
comm_->connect(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, recvConnFuture.get()));
auto sendTransport = getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice);
auto recvTransport = getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice);
auto connFutures = {comm_->connect(sendTransport, sendToRank), comm_->connect(recvTransport, recvFromRank)};
auto semaFutures = {comm_->buildSemaphore(connFutures.begin()->get(), sendToRank),
comm_->buildSemaphore((connFutures.begin() + 1)->get(), recvFromRank)};
semaphores.emplace_back(semaFutures.begin()->get());
semaphores.emplace_back((semaFutures.begin() + 1)->get());
} else {
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
memorySemaphores.push_back(memorySemaphores[0]); // reuse the send channel if worldSize is 2
auto sendTransport = getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice);
auto connFuture = comm_->connect(sendTransport, sendToRank);
semaphores.emplace_back(comm_->buildSemaphore(connFuture.get(), sendToRank).get());
semaphores.emplace_back(semaphores[0]); // reuse the semaphore if worldSize is 2
}
std::vector<mscclpp::RegisteredMemory> localMemories;
@@ -172,9 +174,9 @@ void SendRecvTestEngine::setupConnections() {
for (int i : {0, 1}) {
auto regMem = comm_->registerMemory(devicePtrs_[i].get(), args_.maxBytes, mscclpp::Transport::CudaIpc | ibDevice);
comm_->sendMemory(regMem, ranks[i], 0);
comm_->sendMemory(regMem, ranks[i]);
localMemories.push_back(regMem);
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i], 0));
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i]));
}
// swap to make sure devicePtrs_[0] in local rank write to devicePtrs_[1] in remote rank
@@ -182,7 +184,7 @@ void SendRecvTestEngine::setupConnections() {
std::vector<DeviceHandle<mscclpp::MemoryChannel>> memoryChannelHandles(2);
for (int i : {0, 1}) {
// We assume ranks in the same node
memoryChannels_.emplace_back(memorySemaphores[i], futureRemoteMemory[i].get(), (void*)localMemories[i].data());
memoryChannels_.emplace_back(semaphores[i], futureRemoteMemory[i].get(), (void*)localMemories[i].data());
}
std::transform(memoryChannels_.begin(), memoryChannels_.end(), memoryChannelHandles.begin(),
[](const mscclpp::MemoryChannel& memoryChannel) { return memoryChannel.deviceHandle(); });

View File

@@ -29,8 +29,8 @@ TEST_F(LocalCommunicatorTest, RegisterMemory) {
TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
int dummy[42];
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
comm->sendMemory(memory, 0, 0);
auto memoryFuture = comm->recvMemory(0, 0);
comm->sendMemory(memory, 0);
auto memoryFuture = comm->recvMemory(0);
auto sameMemory = memoryFuture.get();
EXPECT_EQ(sameMemory.data(), memory.data());
EXPECT_EQ(sameMemory.size(), memory.size());

View File

@@ -46,7 +46,7 @@ static void localPortChannelTest(mscclpp::Transport transport) {
bootstrap->initialize(mscclpp::TcpBootstrap::createUniqueId());
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
auto connection = communicator->connect(/*remoteRank*/ 0, /*tag*/ 0, transport).get();
auto connection = communicator->connect(transport, /*remoteRank*/ 0).get();
const size_t bytes = 4 * 1024 * 1024;
auto srcBuff = mscclpp::GpuBuffer(bytes).memory();