mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
New semaphore constructors (#559)
More intuitive interfaces for creating semaphores and channels. Also allows channel construction using third-party bootstrappers directly without overriding MSCCL++ Bootstrap.
This commit is contained in:
@@ -275,8 +275,8 @@ static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_pt
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
|
||||
if (i == rank) continue;
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(i, 0));
|
||||
comm->sendMemory(memory, i, 0);
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(i));
|
||||
comm->sendMemory(memory, i);
|
||||
}
|
||||
std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories),
|
||||
[](const auto& future) { return future.get(); });
|
||||
@@ -613,7 +613,7 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
|
||||
for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) {
|
||||
if (i == rank) continue;
|
||||
mscclpp::Transport transport = getTransport(rank, i);
|
||||
connectionFutures.push_back(mscclppComm->connect(i, 0, transport));
|
||||
connectionFutures.push_back(mscclppComm->connect(transport, i));
|
||||
}
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
|
||||
@@ -39,16 +39,17 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
|
||||
if (r == rank) continue;
|
||||
mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
|
||||
// Connect with all other ranks
|
||||
connections[r] = comm.connect(r, 0, transport);
|
||||
connections[r] = comm.connect(transport, r);
|
||||
auto memory = comm.registerMemory(data, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemory(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemory(r, 0));
|
||||
comm.sendMemory(memory, r);
|
||||
remoteMemories.push_back(comm.recvMemory(r));
|
||||
}
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
|
||||
auto sema = communicator->buildSemaphore(connections[r].get(), r).get();
|
||||
semaphoreIds.push_back(proxyService->addSemaphore(sema));
|
||||
}
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
|
||||
|
||||
@@ -21,11 +21,11 @@ namespace mscclpp {
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
|
||||
/// Unique ID for a process. This is a MSCCLPP_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
|
||||
/// Unique ID for initializing the TcpBootstrap.
|
||||
using UniqueId = std::array<uint8_t, MSCCLPP_UNIQUE_ID_BYTES>;
|
||||
|
||||
/// Return a version string.
|
||||
/// @return A string representing the version of MSCCL++ in the format "major.minor.patch".
|
||||
/// @return The MSCCL++ version string in "major.minor.patch" format.
|
||||
std::string version();
|
||||
|
||||
/// Base class for bootstraps.
|
||||
@@ -220,9 +220,6 @@ enum class Transport {
|
||||
NumTransports, // The number of transports.
|
||||
};
|
||||
|
||||
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
|
||||
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
|
||||
|
||||
namespace detail {
|
||||
const size_t TransportFlagsSize = 12;
|
||||
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
|
||||
@@ -234,119 +231,99 @@ using TransportFlagsBase = std::bitset<TransportFlagsSize>;
|
||||
/// Stores transport flags.
|
||||
class TransportFlags : private detail::TransportFlagsBase {
|
||||
public:
|
||||
/// Default constructor for TransportFlags.
|
||||
/// Constructor.
|
||||
TransportFlags() = default;
|
||||
|
||||
/// Constructor for TransportFlags that takes a Transport enum value.
|
||||
///
|
||||
/// Constructor.
|
||||
/// @param transport The transport to set the flag for.
|
||||
TransportFlags(Transport transport);
|
||||
|
||||
/// Check if a specific transport flag is set.
|
||||
///
|
||||
/// @param transport The transport to check the flag for.
|
||||
/// @return True if the flag is set, false otherwise.
|
||||
bool has(Transport transport) const;
|
||||
|
||||
/// Check if no transport flags are set.
|
||||
///
|
||||
/// @return True if no flags are set, false otherwise.
|
||||
bool none() const;
|
||||
|
||||
/// Check if any transport flags are set.
|
||||
///
|
||||
/// @return True if any flags are set, false otherwise.
|
||||
bool any() const;
|
||||
|
||||
/// Check if all transport flags are set.
|
||||
///
|
||||
/// @return True if all flags are set, false otherwise.
|
||||
bool all() const;
|
||||
|
||||
/// Get the number of transport flags that are set.
|
||||
///
|
||||
/// @return The number of flags that are set.
|
||||
size_t count() const;
|
||||
|
||||
/// Bitwise OR assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the OR operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator|=(TransportFlags other);
|
||||
|
||||
/// Bitwise OR operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the OR operation with.
|
||||
/// @return A new TransportFlags object with the result of the OR operation.
|
||||
TransportFlags operator|(TransportFlags other) const;
|
||||
|
||||
/// Bitwise OR operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the OR operation with.
|
||||
/// @return A new TransportFlags object with the result of the OR operation.
|
||||
TransportFlags operator|(Transport transport) const;
|
||||
|
||||
/// Bitwise AND assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the AND operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator&=(TransportFlags other);
|
||||
|
||||
/// Bitwise AND operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the AND operation with.
|
||||
/// @return A new TransportFlags object with the result of the AND operation.
|
||||
TransportFlags operator&(TransportFlags other) const;
|
||||
|
||||
/// Bitwise AND operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the AND operation with.
|
||||
/// @return A new TransportFlags object with the result of the AND operation.
|
||||
TransportFlags operator&(Transport transport) const;
|
||||
|
||||
/// Bitwise XOR assignment operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the XOR operation with.
|
||||
/// @return A reference to the modified TransportFlags.
|
||||
TransportFlags& operator^=(TransportFlags other);
|
||||
|
||||
/// Bitwise XOR operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to perform the XOR operation with.
|
||||
/// @return A new TransportFlags object with the result of the XOR operation.
|
||||
TransportFlags operator^(TransportFlags other) const;
|
||||
|
||||
/// Bitwise XOR operator for TransportFlags and Transport.
|
||||
///
|
||||
/// @param transport The Transport to perform the XOR operation with.
|
||||
/// @return A new TransportFlags object with the result of the XOR operation.
|
||||
TransportFlags operator^(Transport transport) const;
|
||||
|
||||
/// Bitwise NOT operator for TransportFlags.
|
||||
///
|
||||
/// @return A new TransportFlags object with the result of the NOT operation.
|
||||
TransportFlags operator~() const;
|
||||
|
||||
/// Equality comparison operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to compare with.
|
||||
/// @return True if the two TransportFlags objects are equal, false otherwise.
|
||||
bool operator==(TransportFlags other) const;
|
||||
|
||||
/// Inequality comparison operator for TransportFlags.
|
||||
///
|
||||
/// @param other The other TransportFlags to compare with.
|
||||
/// @return True if the two TransportFlags objects are not equal, false otherwise.
|
||||
bool operator!=(TransportFlags other) const;
|
||||
|
||||
/// Convert the TransportFlags object to a bitset representation.
|
||||
///
|
||||
/// @return A detail::TransportFlagsBase object representing the TransportFlags object.
|
||||
detail::TransportFlagsBase toBitset() const;
|
||||
|
||||
private:
|
||||
/// Private constructor for TransportFlags that takes a bitset representation.
|
||||
///
|
||||
/// @param bitset The bitset representation of the TransportFlags object.
|
||||
TransportFlags(detail::TransportFlagsBase bitset);
|
||||
};
|
||||
@@ -378,47 +355,64 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) ^ transport2;
|
||||
}
|
||||
|
||||
/// Available device types.
enum class DeviceType {
  Unknown,  // Unknown device type.
  CPU,      // CPU device type.
  GPU,      // GPU device type.
};

/// Identifies a device by its type and an optional numeric ID.
struct Device {
  /// Constructor. Yields a device of type DeviceType::Unknown with no specific ID (-1).
  Device() = default;

  /// Constructor.
  /// @param type Device type.
  /// @param id Device ID. Default is -1 (no specific ID).
  Device(DeviceType type, int id = -1) : type(type), id(id) {}

  /// Device Type.
  /// Initialized in-class so that a default-constructed Device never holds an indeterminate value.
  DeviceType type = DeviceType::Unknown;

  /// Device ID. -1 means no specific ID.
  int id = -1;
};
|
||||
|
||||
class Context;
|
||||
class Connection;
|
||||
|
||||
/// Represents a block of memory that has been registered to a Context.
|
||||
/// Block of memory that has been registered to a Context.
|
||||
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
|
||||
/// to other processes, hence allowing their access to the memory block.
|
||||
class RegisteredMemory {
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
RegisteredMemory() = default;
|
||||
|
||||
/// Destructor.
|
||||
~RegisteredMemory();
|
||||
|
||||
/// Get a pointer to the memory block.
|
||||
///
|
||||
/// @return A pointer to the memory block.
|
||||
void* data() const;
|
||||
|
||||
/// Get a pointer to the original memory block.
|
||||
///
|
||||
/// @return A pointer to the original memory block.
|
||||
void* originalDataPtr() const;
|
||||
|
||||
/// Get the size of the memory block.
|
||||
///
|
||||
/// @return The size of the memory block.
|
||||
size_t size() const;
|
||||
|
||||
/// Get the transport flags associated with the memory block.
|
||||
///
|
||||
/// @return The transport flags associated with the memory block.
|
||||
TransportFlags transports() const;
|
||||
|
||||
/// Serialize the RegisteredMemory object to a vector of characters.
|
||||
///
|
||||
/// @return A vector of characters representing the serialized RegisteredMemory object.
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize a RegisteredMemory object from a vector of characters.
|
||||
///
|
||||
/// @param data A vector of characters representing a serialized RegisteredMemory object.
|
||||
/// @return A deserialized RegisteredMemory object.
|
||||
static RegisteredMemory deserialize(const std::vector<char>& data);
|
||||
@@ -430,31 +424,32 @@ class RegisteredMemory {
|
||||
|
||||
friend class Context;
|
||||
friend class Connection;
|
||||
friend class SemaphoreStub;
|
||||
};
|
||||
|
||||
/// Represents one end of a connection.
|
||||
/// One end of a connection.
|
||||
class Endpoint {
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
Endpoint() = default;
|
||||
|
||||
/// Get the transport used.
|
||||
///
|
||||
/// @return The transport used.
|
||||
Transport transport();
|
||||
Transport transport() const;
|
||||
|
||||
/// Get the device used.
|
||||
/// @return The device used.
|
||||
const Device& device() const;
|
||||
|
||||
/// Get the maximum write queue size.
|
||||
///
|
||||
/// @return The maximum number of write requests that can be queued.
|
||||
int maxWriteQueueSize();
|
||||
int maxWriteQueueSize() const;
|
||||
|
||||
/// Serialize the Endpoint object to a vector of characters.
|
||||
///
|
||||
/// @return A vector of characters representing the serialized Endpoint object.
|
||||
std::vector<char> serialize();
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize an Endpoint object from a vector of characters.
|
||||
///
|
||||
/// @param data A vector of characters representing a serialized Endpoint object.
|
||||
/// @return A deserialized Endpoint object.
|
||||
static Endpoint deserialize(const std::vector<char>& data);
|
||||
@@ -468,13 +463,13 @@ class Endpoint {
|
||||
friend class Connection;
|
||||
};
|
||||
|
||||
/// Represents a connection between two processes.
|
||||
/// Connection between two processes.
|
||||
class Connection {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
|
||||
Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
|
||||
: context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
|
||||
/// @param localEndpoint The local endpoint of the connection.
|
||||
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
|
||||
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
|
||||
|
||||
/// Destructor.
|
||||
virtual ~Connection() = default;
|
||||
@@ -498,34 +493,35 @@ class Connection {
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
|
||||
/// Flush any pending writes to the remote process.
|
||||
virtual void flush(int64_t timeoutUsec = 3e7) = 0;
|
||||
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
|
||||
virtual void flush(int64_t timeoutUsec = -1) = 0;
|
||||
|
||||
/// Get the transport used by the local process.
|
||||
///
|
||||
/// @return The transport used by the local process.
|
||||
virtual Transport transport() const = 0;
|
||||
|
||||
/// Get the transport used by the remote process.
|
||||
///
|
||||
/// @return The transport used by the remote process.
|
||||
virtual Transport remoteTransport() const = 0;
|
||||
|
||||
/// Get the name of the transport used for this connection
|
||||
///
|
||||
/// @return A string formatted as "localTransportName -> remoteTransportName".
|
||||
std::string getTransportName() const;
|
||||
/// Get the context associated with this connection.
|
||||
/// @return A shared pointer to the context associated with this connection.
|
||||
std::shared_ptr<Context> context() const { return context_; }
|
||||
|
||||
/// Get the maximum write queue size
|
||||
///
|
||||
/// Get the device used by the local endpoint.
|
||||
/// @return The device used by the local endpoint.
|
||||
const Device& localDevice() const;
|
||||
|
||||
/// Get the maximum write queue size.
|
||||
/// @return The maximum number of write requests that can be queued.
|
||||
int getMaxWriteQueueSize() const;
|
||||
|
||||
protected:
|
||||
// Internal methods for getting implementation pointers.
|
||||
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
|
||||
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
|
||||
|
||||
std::shared_ptr<Context> context_;
|
||||
Endpoint localEndpoint_;
|
||||
int maxWriteQueueSize_;
|
||||
};
|
||||
|
||||
@@ -537,6 +533,7 @@ struct EndpointConfig {
|
||||
static const int DefaultMaxWrPerSend = 64;
|
||||
|
||||
Transport transport;
|
||||
Device device;
|
||||
int ibMaxCqSize;
|
||||
int ibMaxCqPollNum;
|
||||
int ibMaxSendWr;
|
||||
@@ -546,15 +543,18 @@ struct EndpointConfig {
|
||||
/// Constructor that takes a transport and sets the other fields to their default values.
|
||||
///
|
||||
/// @param transport The transport to use.
|
||||
/// @param device The device to use.
|
||||
/// @param ibMaxCqSize The maximum completion queue size.
|
||||
/// @param ibMaxCqPollNum The maximum completion queue poll number.
|
||||
/// @param ibMaxSendWr The maximum send work requests.
|
||||
/// @param ibMaxWrPerSend The maximum work requests per send.
|
||||
/// @param maxWriteQueueSize The maximum write queue size.
|
||||
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
|
||||
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
|
||||
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
|
||||
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
|
||||
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
|
||||
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
|
||||
int maxWriteQueueSize = -1)
|
||||
: transport(transport),
|
||||
device(device),
|
||||
ibMaxCqSize(ibMaxCqSize),
|
||||
ibMaxCqPollNum(ibMaxCqPollNum),
|
||||
ibMaxSendWr(ibMaxSendWr),
|
||||
@@ -562,7 +562,7 @@ struct EndpointConfig {
|
||||
maxWriteQueueSize(maxWriteQueueSize) {}
|
||||
};
|
||||
|
||||
/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
|
||||
/// Context for communication. This provides a low-level interface for forming connections in use-cases
|
||||
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
|
||||
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
|
||||
/// connect() method.
|
||||
@@ -619,6 +619,62 @@ class Context : public std::enable_shared_from_this<Context> {
|
||||
friend class Endpoint;
|
||||
};
|
||||
|
||||
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
/// A stub is created from a Connection, exposes the RegisteredMemory associated with it, and can be
/// serialized to / deserialized from a vector of characters.
|
||||
class SemaphoreStub {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param connection A shared pointer to the connection associated with this semaphore.
|
||||
SemaphoreStub(std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Get the memory associated with this semaphore.
|
||||
/// @return A reference to the registered memory for this semaphore.
|
||||
const RegisteredMemory& memory() const;
|
||||
|
||||
/// Serialize into a vector of characters.
|
||||
/// @return A vector of characters representing the serialized SemaphoreStub object.
|
||||
std::vector<char> serialize() const;
|
||||
|
||||
/// Deserialize a SemaphoreStub object from a vector of characters.
|
||||
/// @param data A vector of characters representing a serialized SemaphoreStub object.
|
||||
/// @return A deserialized SemaphoreStub object.
|
||||
static SemaphoreStub deserialize(const std::vector<char>& data);
|
||||
|
||||
protected:
|
||||
/// Implementation type; only declared here, its definition is not visible in this section.
struct Impl;
|
||||
/// Constructor from an existing implementation object.
/// @param pimpl A shared pointer to the implementation.
SemaphoreStub(std::shared_ptr<Impl> pimpl);
|
||||
/// Shared pointer to the implementation.
std::shared_ptr<Impl> pimpl_;
|
||||
|
||||
/// Grants Semaphore access to the protected members of the stub.
friend class Semaphore;
|
||||
};
|
||||
|
||||
/// Semaphore used by channels for synchronization.
/// Built from a local and a remote SemaphoreStub; exposes the associated connection and the local and
/// remote registered memories.
|
||||
class Semaphore {
|
||||
public:
|
||||
/// Constructor. Leaves the implementation pointer empty.
/// NOTE(review): the accessors below presumably require a semaphore built from stubs -- confirm.
|
||||
Semaphore() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param localStub SemaphoreStub allocated on the local process.
|
||||
/// @param remoteStub SemaphoreStub allocated on the remote process.
|
||||
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
|
||||
|
||||
/// Get the connection associated with this semaphore.
|
||||
/// @return A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Get the local memory associated with this semaphore.
|
||||
/// @return A reference to the local registered memory.
|
||||
const RegisteredMemory& localMemory() const;
|
||||
|
||||
/// Get the remote memory associated with this semaphore.
|
||||
/// @return A reference to the remote registered memory.
|
||||
const RegisteredMemory& remoteMemory() const;
|
||||
|
||||
protected:
|
||||
/// Implementation type; only declared here, its definition is not visible in this section.
struct Impl;
|
||||
/// Shared pointer to the implementation.
std::shared_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
/// Deprecated alias for std::shared_future<T>; use std::shared_future directly.
/// @tparam T The type of the value held by the future.
template <typename T>
|
||||
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
|
||||
std::shared_future<T>;
|
||||
@@ -630,20 +686,26 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// 2. Call registerMemory() to register memory regions that will be used for communication.
|
||||
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
|
||||
/// other processes.
|
||||
/// 4. Call get() on all futures returned by connect() and recvMemory().
|
||||
/// 5. All done; use connections and registered memories to build channels.
|
||||
/// 4. Call get() on futures returned by connect() to obtain the connections.
|
||||
/// 5. Call buildSemaphore() on each connection to create a Semaphore for that connection.
|
||||
/// 6. Call get() on all futures returned by buildSemaphore() and recvMemory().
|
||||
/// 7. All done; use semaphores and registered memories to build channels.
|
||||
///
|
||||
/// CAUTION: The order of method calls matters when the same remote rank and tags are used. That is, the i-th
|
||||
/// "sending" method call (sendMemory(), connect(), and buildSemaphore()) on the local rank must be matched
|
||||
/// by the i-th "receiving" method call (recvMemory(), connect(), and buildSemaphore()) on the remote rank.
|
||||
///
|
||||
/// Correct Example 1:
|
||||
/// ```cpp
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory1, 1, tag);
|
||||
/// communicator.sendMemory(memory2, 1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// // Rank 1
|
||||
/// auto mem1 = communicator.recvMemory(0, tag);
|
||||
/// auto mem2 = communicator.recvMemory(0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
|
||||
/// mem2.get(); // This will return memory2.
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// mem1.get(); // This will return memory1.
|
||||
@@ -654,13 +716,13 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory0, 1, tag);
|
||||
/// auto mem1 = communicator.recvMemory(1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// mem1.get(); // This will return memory1.
|
||||
/// // Rank 1
|
||||
/// auto mem0 = communicator.recvMemory(0, tag);
|
||||
/// communicator.sendMemory(memory1, 0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag);
|
||||
/// mem0.get(); // This will return memory0.
|
||||
/// connection.get(); // This will return the connection.
|
||||
/// ```
|
||||
@@ -670,10 +732,10 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
|
||||
/// // Rank 0
|
||||
/// communicator.sendMemory(memory0, 1, tag);
|
||||
/// auto mem1 = communicator.recvMemory(1, tag);
|
||||
/// auto connection = communicator.connect(1, tag, Transport::CudaIpc);
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 1, tag);
|
||||
/// // Rank 1
|
||||
/// auto mem0 = communicator.recvMemory(0, tag);
|
||||
/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
|
||||
/// auto connection = communicator.connect(Transport::CudaIpc, 0, tag); // undefined behavior
|
||||
/// communicator.sendMemory(memory1, 0, tag);
|
||||
/// ```
|
||||
/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
|
||||
@@ -691,12 +753,10 @@ class Communicator {
|
||||
~Communicator();
|
||||
|
||||
/// Returns the bootstrap held by this communicator.
|
||||
///
|
||||
/// @return The bootstrap held by this communicator.
|
||||
std::shared_ptr<Bootstrap> bootstrap();
|
||||
|
||||
/// Returns the context held by this communicator.
|
||||
///
|
||||
/// @return The context held by this communicator.
|
||||
std::shared_ptr<Context> context();
|
||||
|
||||
@@ -723,7 +783,7 @@ class Communicator {
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send.
|
||||
///
|
||||
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
|
||||
void sendMemory(RegisteredMemory memory, int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup(
|
||||
RegisteredMemory memory, int remoteRank, int tag) {
|
||||
@@ -752,7 +812,7 @@ class Communicator {
|
||||
/// @param tag The tag to use for identifying the receive.
|
||||
/// @return A future of registered memory.
|
||||
///
|
||||
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag);
|
||||
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated(
|
||||
"Use recvMemory() instead. This will be removed in a future release.")]] NonblockingFuture<RegisteredMemory>
|
||||
@@ -762,7 +822,7 @@ class Communicator {
|
||||
|
||||
/// Connect to a remote rank.
|
||||
///
|
||||
/// This function will start (but not be waiting for) sending metadata about the local endpoint to the remote rank,
|
||||
/// This function will start (but not wait for) sending metadata about the local endpoint to the remote rank,
|
||||
/// and return a future connection without waiting for the remote rank to respond.
|
||||
/// The connection will be established when the remote rank responds with its own endpoint and the local rank calls
|
||||
/// the first get() on the future.
|
||||
@@ -783,19 +843,30 @@ class Communicator {
|
||||
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
|
||||
/// back to back.
|
||||
///
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
///
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
|
||||
|
||||
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
|
||||
shared_future<std::shared_ptr<Connection>>
|
||||
connect(int remoteRank, int tag, EndpointConfig localConfig);
|
||||
|
||||
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
|
||||
std::shared_ptr<Connection>>
|
||||
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return connect(remoteRank, tag, localConfig);
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
}
|
||||
|
||||
/// Build a semaphore for cross-process synchronization.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the operation.
|
||||
/// @return A future of the built semaphore.
|
||||
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
|
||||
|
||||
/// Get the remote rank a connection is connected to.
|
||||
///
|
||||
/// @param connection The connection to get the remote rank for.
|
||||
@@ -842,6 +913,10 @@ using PacketPayload = typename T::Payload;
|
||||
|
||||
namespace std {
|
||||
|
||||
std::string to_string(const mscclpp::Transport& transport);
|
||||
|
||||
std::string to_string(const mscclpp::Device& device);
|
||||
|
||||
/// Specialization of the std::hash template for mscclpp::TransportFlags.
|
||||
template <>
|
||||
struct hash<mscclpp::TransportFlags>;
|
||||
|
||||
@@ -13,24 +13,24 @@
|
||||
|
||||
/// Throw mscclpp::CudaError if @p cmd does not return cudaSuccess.
|
||||
/// @param cmd The command to execute.
|
||||
#define MSCCLPP_CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
#define MSCCLPP_CUDATHROW(cmd) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
if (err != cudaSuccess) { \
|
||||
throw ::mscclpp::CudaError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
/// Throw mscclpp::CuError if @p cmd does not return CUDA_SUCCESS.
|
||||
/// @param cmd The command to execute.
|
||||
#define MSCCLPP_CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
throw mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
#define MSCCLPP_CUTHROW(cmd) \
|
||||
do { \
|
||||
CUresult err = cmd; \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
throw ::mscclpp::CuError(std::string("Call to " #cmd " failed. ") + __FILE__ + ":" + std::to_string(__LINE__), \
|
||||
err); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -18,13 +18,19 @@ struct BaseMemoryChannel {
|
||||
std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
BaseMemoryChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param semaphore Semaphore used to synchronize the communication.
|
||||
BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore Semaphore used to synchronize the communication.
|
||||
BaseMemoryChannel(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param other Other BaseMemoryChannel to copy from.
|
||||
BaseMemoryChannel(const BaseMemoryChannel& other) = default;
|
||||
|
||||
BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default;
|
||||
@@ -33,9 +39,8 @@ struct BaseMemoryChannel {
|
||||
using DeviceHandle = BaseMemoryChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the BaseMemoryChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
@@ -47,7 +52,7 @@ struct MemoryChannel : public BaseMemoryChannel {
|
||||
void* packetBuffer_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
MemoryChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
@@ -59,13 +64,20 @@ struct MemoryChannel : public BaseMemoryChannel {
|
||||
MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
|
||||
void* packetBuffer = nullptr);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param dst Registered memory of the destination.
|
||||
/// @param src The source memory address.
|
||||
/// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
|
||||
/// unpackPacket() and unpackPackets() methods are not available.
|
||||
MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr);
|
||||
|
||||
/// Device-side handle for MemoryChannel.
|
||||
using DeviceHandle = MemoryChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the MemoryChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -35,6 +35,11 @@ class ProxyService : public BaseProxyService {
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId addSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
/// @return The ID of the semaphore.
|
||||
@@ -87,22 +92,36 @@ struct BasePortChannel {
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
BasePortChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param proxy The proxy used for communication.
|
||||
BasePortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
|
||||
std::shared_ptr<Proxy> proxy);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore used to synchronize the communication.
|
||||
/// @param proxy The proxy used for communication.
|
||||
BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy);
|
||||
|
||||
/// Copy constructor.
|
||||
/// @param other The other BasePortChannel to copy from.
|
||||
BasePortChannel(const BasePortChannel& other) = default;
|
||||
|
||||
/// Copy assignment operator.
|
||||
/// @param other The other BasePortChannel to assign from. Taken by const reference so that const objects and
/// temporaries can be assigned from as well.
/// @return Reference to this BasePortChannel.
|
||||
BasePortChannel& operator=(const BasePortChannel& other) = default;
|
||||
|
||||
/// Device-side handle for BasePortChannel.
|
||||
using DeviceHandle = BasePortChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the BasePortChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
@@ -113,7 +132,7 @@ struct PortChannel : public BasePortChannel {
|
||||
MemoryId src_;
|
||||
|
||||
public:
|
||||
/// Default constructor.
|
||||
/// Constructor.
|
||||
PortChannel() = default;
|
||||
|
||||
/// Constructor.
|
||||
@@ -125,19 +144,29 @@ struct PortChannel : public BasePortChannel {
|
||||
PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy,
|
||||
MemoryId dst, MemoryId src);
|
||||
|
||||
/// Constructor.
|
||||
/// @param semaphoreId The ID of the semaphore.
|
||||
/// @param semaphore The semaphore.
|
||||
/// @param proxy The proxy.
|
||||
/// @param dst The destination memory region.
|
||||
/// @param src The source memory region.
|
||||
PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore, std::shared_ptr<Proxy> proxy, MemoryId dst,
|
||||
MemoryId src);
|
||||
|
||||
/// Copy constructor.
|
||||
/// @param other The other PortChannel to copy from.
|
||||
PortChannel(const PortChannel& other) = default;
|
||||
|
||||
/// Copy assignment operator (defaulted).
|
||||
/// @param other The other PortChannel to assign from. Declared as a non-const reference, so only non-const
/// lvalues can be assigned from.
/// @return Reference to this PortChannel.
|
||||
PortChannel& operator=(PortChannel& other) = default;
|
||||
|
||||
/// Device-side handle for PortChannel.
|
||||
using DeviceHandle = PortChannelDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
///
|
||||
/// User should make sure the PortChannel is not released when using the returned handle.
|
||||
///
|
||||
/// @return The device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -12,63 +12,18 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// A base class for semaphores.
|
||||
///
|
||||
/// A semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a
|
||||
/// data transfer. The local peer signals the remote peer that it has completed a data transfer by incrementing the
|
||||
/// outbound semaphore ID. The incremented outbound semaphore ID is copied to the remote peer's inbound semaphore ID so
|
||||
/// that the remote peer can wait for the local peer to complete a data transfer. Vice versa, the remote peer signals
|
||||
/// the local peer that it has completed a data transfer by incrementing the remote peer's outbound semaphore ID and
|
||||
/// copying the incremented value to the local peer's inbound semaphore ID.
|
||||
///
|
||||
/// @tparam InboundDeleter The deleter for inbound semaphore IDs. This is either `std::default_delete` for host memory
|
||||
/// or a device-memory deleter such as `detail::GpuDeleter`.
|
||||
/// @tparam OutboundDeleter The deleter for outbound semaphore IDs. This is either `std::default_delete` for host memory
|
||||
/// or a device-memory deleter such as `detail::GpuDeleter`.
|
||||
///
|
||||
template <template <typename> typename InboundDeleter, template <typename> typename OutboundDeleter>
|
||||
class BaseSemaphore {
|
||||
protected:
|
||||
/// The registered memory for the remote peer's inbound semaphore ID.
/// Not set by the BaseSemaphore constructor; it is left default-constructed here.
|
||||
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
|
||||
|
||||
/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
|
||||
///
|
||||
/// The location of localInboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphore_;
|
||||
|
||||
/// The expected inbound semaphore ID to be incremented by the local peer and compared to the
|
||||
/// localInboundSemaphore_.
|
||||
///
|
||||
/// The location of expectedInboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphore_;
|
||||
|
||||
/// The outbound semaphore ID that is incremented by the local peer and copied to the remote peer's
|
||||
/// localInboundSemaphore_.
|
||||
///
|
||||
/// The location of outboundSemaphore_ can be either on the host or on the device.
|
||||
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphore_;
|
||||
|
||||
public:
|
||||
/// Constructs a BaseSemaphore.
|
||||
///
/// Ownership of all three pointers is moved into the corresponding members.
///
|
||||
/// @param localInboundSemaphoreId The inbound semaphore ID; stored in localInboundSemaphore_.
|
||||
/// @param expectedInboundSemaphoreId The expected inbound semaphore ID; stored in expectedInboundSemaphore_.
|
||||
/// @param outboundSemaphoreId The outbound semaphore ID; stored in outboundSemaphore_.
|
||||
BaseSemaphore(std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> localInboundSemaphoreId,
|
||||
std::unique_ptr<uint64_t, InboundDeleter<uint64_t>> expectedInboundSemaphoreId,
|
||||
std::unique_ptr<uint64_t, OutboundDeleter<uint64_t>> outboundSemaphoreId)
|
||||
: localInboundSemaphore_(std::move(localInboundSemaphoreId)),
|
||||
expectedInboundSemaphore_(std::move(expectedInboundSemaphoreId)),
|
||||
outboundSemaphore_(std::move(outboundSemaphoreId)) {}
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the host to the device.
|
||||
class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::default_delete> {
|
||||
class Host2DeviceSemaphore {
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
Semaphore semaphore_;
|
||||
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
|
||||
std::unique_ptr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructs a Host2DeviceSemaphore from an existing Semaphore.
|
||||
/// @param semaphore The Semaphore to construct this Host2DeviceSemaphore from.
|
||||
Host2DeviceSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
@@ -76,7 +31,7 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection();
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Signal the device.
|
||||
void signal();
|
||||
@@ -85,13 +40,22 @@ class Host2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, std::defau
|
||||
using DeviceHandle = Host2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle();
|
||||
DeviceHandle deviceHandle() const;
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the local host to a remote host.
|
||||
class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::default_delete> {
|
||||
class Host2HostSemaphore {
|
||||
private:
|
||||
Semaphore semaphore_;
|
||||
std::unique_ptr<uint64_t> expectedInboundToken_;
|
||||
std::unique_ptr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// Constructs a Host2HostSemaphore from an existing Semaphore.
|
||||
/// @param semaphore The Semaphore to construct this Host2HostSemaphore from.
|
||||
Host2HostSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
|
||||
/// Host2HostSemaphore.
|
||||
@@ -99,7 +63,7 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection();
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Signal the remote host.
|
||||
void signal();
|
||||
@@ -111,29 +75,34 @@ class Host2HostSemaphore : public BaseSemaphore<std::default_delete, std::defaul
|
||||
/// Wait for the remote host to signal.
|
||||
/// @param maxSpinCount The maximum number of spin counts before throwing an exception. Never throws if negative.
|
||||
void wait(int64_t maxSpinCount = 10000000);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Connection> connection_;
|
||||
};
|
||||
|
||||
/// A semaphore for sending signals from the local device to a peer device via a GPU thread.
|
||||
class MemoryDevice2DeviceSemaphore : public BaseSemaphore<detail::GpuDeleter, detail::GpuDeleter> {
|
||||
class MemoryDevice2DeviceSemaphore {
|
||||
private:
|
||||
Semaphore semaphore_;
|
||||
detail::UniqueGpuPtr<uint64_t> expectedInboundToken_;
|
||||
detail::UniqueGpuPtr<uint64_t> outboundToken_;
|
||||
|
||||
public:
|
||||
/// Constructs a MemoryDevice2DeviceSemaphore from an existing Semaphore.
|
||||
/// @param semaphore The Semaphore to construct this MemoryDevice2DeviceSemaphore from.
|
||||
MemoryDevice2DeviceSemaphore(const Semaphore& semaphore);
|
||||
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Default constructor is deleted; a MemoryDevice2DeviceSemaphore cannot be default-constructed.
|
||||
MemoryDevice2DeviceSemaphore() = delete;
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
|
||||
/// Device-side handle for MemoryDevice2DeviceSemaphore.
|
||||
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
/// Returns the device-side handle.
|
||||
DeviceHandle deviceHandle() const;
|
||||
|
||||
bool isRemoteInboundSemaphoreIdSet_;
|
||||
};
|
||||
|
||||
/// @deprecated Use MemoryDevice2DeviceSemaphore instead.
|
||||
|
||||
@@ -33,28 +33,28 @@ struct Host2DeviceSemaphoreDeviceHandle {
|
||||
/// Thread-safe read of expected inbound value.
|
||||
/// @return The expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of expected inbound value.
|
||||
/// @return The incremented expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
|
||||
/// Thread-safe read of inbound value.
|
||||
/// @return The inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
/// A local memory space where a host thread (on behalf of the remote device) will write its semaphore value
|
||||
/// and the local device will read it.
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* inboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
|
||||
uint64_t* expectedInboundToken;
|
||||
};
|
||||
|
||||
/// Device-side handle for MemoryDevice2DeviceSemaphore.
|
||||
@@ -85,61 +85,61 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle {
|
||||
auto outbound = incOutbound();
|
||||
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800)
|
||||
// Using memoryOrderSeqCst is faster for A100.
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderSeqCst);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst);
|
||||
#else
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelease);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderRelease);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data.
|
||||
MSCCLPP_DEVICE_INLINE void relaxedSignal() {
|
||||
auto outbound = incOutbound();
|
||||
atomicStore(remoteInboundSemaphoreId, outbound, memoryOrderRelaxed);
|
||||
atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe read of expected inbound value.
|
||||
/// @return The expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadExpectedInbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(expectedInboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of expected inbound value.
|
||||
/// @return The incremented expected inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incExpectedInbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(expectedInboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
|
||||
/// Thread-safe read of inbound value.
|
||||
/// @return The inbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadInbound() {
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundSemaphoreId, memoryOrderAcquire);
|
||||
return atomicLoad<uint64_t, scopeSystem>(inboundToken, memoryOrderAcquire);
|
||||
}
|
||||
|
||||
/// Thread-safe read of outbound value.
|
||||
/// @return The outbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() {
|
||||
return atomicLoad<uint64_t, scopeDevice>(outboundSemaphoreId, memoryOrderRelaxed);
|
||||
return atomicLoad<uint64_t, scopeDevice>(outboundToken, memoryOrderRelaxed);
|
||||
}
|
||||
|
||||
/// Thread-safe increment of outbound value.
|
||||
/// @return The incremented outbound value.
|
||||
MSCCLPP_DEVICE_INLINE uint64_t incOutbound() {
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(outboundSemaphoreId, 1, memoryOrderRelaxed) + 1;
|
||||
return atomicFetchAdd<uint64_t, scopeDevice>(outboundToken, 1, memoryOrderRelaxed) + 1;
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
/// A local memory space where the remote device will write its semaphore value and the local device will read it.
|
||||
uint64_t* inboundSemaphoreId;
|
||||
uint64_t* inboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the semaphore value to be written to the remote device.
|
||||
uint64_t* outboundSemaphoreId;
|
||||
uint64_t* outboundToken;
|
||||
|
||||
/// A remote memory space where the local device writes its outboundSemaphoreId on. This is inboundSemaphoreId of the
|
||||
/// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the
|
||||
/// remote device.
|
||||
uint64_t* remoteInboundSemaphoreId;
|
||||
uint64_t* remoteInboundToken;
|
||||
|
||||
/// A local memory space where the local device stores the expected value of the inboundSemaphoreId to wait for.
|
||||
uint64_t* expectedInboundSemaphoreId;
|
||||
/// A local memory space where the local device stores the expected value of the inboundToken to wait for.
|
||||
uint64_t* expectedInboundToken;
|
||||
};
|
||||
|
||||
/// @deprecated Use MemoryDevice2DeviceSemaphoreDeviceHandle instead.
|
||||
|
||||
@@ -14,6 +14,8 @@ from ._mscclpp import (
|
||||
CudaError,
|
||||
CuError,
|
||||
IbError,
|
||||
Device,
|
||||
DeviceType,
|
||||
Communicator,
|
||||
Connection,
|
||||
connect_nvls_collective,
|
||||
@@ -43,6 +45,8 @@ from ._mscclpp import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Device",
|
||||
"DeviceType",
|
||||
"Communicator",
|
||||
"Connection",
|
||||
"connect_nvls_collective",
|
||||
|
||||
@@ -101,7 +101,7 @@ class CommGroup:
|
||||
if endpoint.transport == Transport.Nvls:
|
||||
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect(rank, 0, endpoint)
|
||||
connections[rank] = self.communicator.connect(endpoint, rank)
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
@@ -124,8 +124,8 @@ class CommGroup:
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
future_memories = {}
|
||||
for rank in connections:
|
||||
self.communicator.send_memory(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank, 0)
|
||||
self.communicator.send_memory(local_reg_memory, rank)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank)
|
||||
for rank in connections:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
@@ -112,6 +112,19 @@ void register_core(nb::module_& m) {
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::enum_<DeviceType>(m, "DeviceType")
|
||||
.value("Unknown", DeviceType::Unknown)
|
||||
.value("CPU", DeviceType::CPU)
|
||||
.value("GPU", DeviceType::GPU);
|
||||
|
||||
nb::class_<Device>(m, "Device")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
|
||||
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
|
||||
.def_rw("type", &Device::type)
|
||||
.def_rw("id", &Device::id)
|
||||
.def("__str__", [](const Device& self) { return std::to_string(self); });
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", &RegisteredMemory::data)
|
||||
@@ -120,6 +133,13 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("device", &Endpoint::device)
|
||||
.def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
@@ -131,21 +151,26 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
.def("remote_transport", &Connection::remoteTransport)
|
||||
.def("context", &Connection::context)
|
||||
.def("local_device", &Connection::localDevice)
|
||||
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
|
||||
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
|
||||
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
|
||||
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("device", &EndpointConfig::device)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
.def_static("create", &Context::create)
|
||||
@@ -158,6 +183,19 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Semaphore>(m, "Semaphore")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("localStub"), nb::arg("remoteStub"))
|
||||
.def("connection", &Semaphore::connection)
|
||||
.def("local_memory", &Semaphore::localMemory)
|
||||
.def("remote_memory", &Semaphore::remoteMemory);
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
@@ -172,12 +210,28 @@ void register_core(nb::module_& m) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def(
|
||||
"connect",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def(
|
||||
"connect_on_setup",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("localFlag"), nb::arg("remoteRank"),
|
||||
nb::arg("tag") = 0)
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", [](Communicator*) {});
|
||||
|
||||
@@ -11,12 +11,10 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel> baseMemoryChannel(m, "BaseMemoryChannel");
|
||||
baseMemoryChannel
|
||||
.def("__init__",
|
||||
[](BaseMemoryChannel* baseMemoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore) {
|
||||
new (baseMemoryChannel) BaseMemoryChannel(semaphore);
|
||||
})
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
@@ -26,8 +24,8 @@ void register_memory_channel(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel> memoryChannel(m, "MemoryChannel");
|
||||
memoryChannel
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst,
|
||||
|
||||
@@ -20,13 +20,19 @@ void register_port_channel(nb::module_& m) {
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_semaphore", static_cast<SemaphoreId (ProxyService::*)(const Semaphore&)>(&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_semaphore",
|
||||
static_cast<SemaphoreId (ProxyService::*)(std::shared_ptr<Host2DeviceSemaphore>)>(
|
||||
&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
@@ -41,6 +47,7 @@ void register_port_channel(nb::module_& m) {
|
||||
});
|
||||
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
@@ -11,7 +11,7 @@ using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
@@ -19,13 +19,14 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_rw("inbound_token", &Host2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("expected_inbound_token", &Host2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
@@ -34,16 +35,17 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
|
||||
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
|
||||
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
@@ -28,6 +28,8 @@ from mscclpp import (
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
env,
|
||||
Device,
|
||||
DeviceType,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
@@ -280,7 +282,13 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
for rank in connections:
|
||||
@@ -293,7 +301,13 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ const char* SocketToString(union SocketAddress* addr, char* buf, const int numer
|
||||
static int getTcpFinTimeout() {
|
||||
std::ifstream ifs("/proc/sys/net/ipv4/tcp_fin_timeout");
|
||||
if (!ifs.is_open()) {
|
||||
throw mscclpp::SysError("open /proc/sys/net/ipv4/tcp_fin_timeout failed", errno);
|
||||
throw SysError("open /proc/sys/net/ipv4/tcp_fin_timeout failed", errno);
|
||||
}
|
||||
int timeout;
|
||||
ifs >> timeout;
|
||||
@@ -80,12 +80,12 @@ static int findInterfaces(const char* prefixList, char* names, union SocketAddre
|
||||
#ifdef MSCCLPP_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
struct mscclpp::netIf userIfs[MAX_IFS];
|
||||
struct netIf userIfs[MAX_IFS];
|
||||
bool searchNot = prefixList && prefixList[0] == '^';
|
||||
if (searchNot) prefixList++;
|
||||
bool searchExact = prefixList && prefixList[0] == '=';
|
||||
if (searchExact) prefixList++;
|
||||
int nUserIfs = mscclpp::parseStringList(prefixList, userIfs, MAX_IFS);
|
||||
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
|
||||
|
||||
int found = 0;
|
||||
struct ifaddrs *interfaces, *interface;
|
||||
@@ -110,7 +110,7 @@ static int findInterfaces(const char* prefixList, char* names, union SocketAddre
|
||||
}
|
||||
|
||||
// check against user specified interfaces
|
||||
if (!(mscclpp::matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
|
||||
if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -224,16 +224,16 @@ int FindInterfaceMatchSubnet(char* ifNames, union SocketAddress* localAddrs, uni
|
||||
|
||||
void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair) {
|
||||
if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
|
||||
throw mscclpp::Error("Net : string is null", mscclpp::ErrorCode::InvalidUsage);
|
||||
throw Error("Net : string is null", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
bool ipv6 = ip_port_pair[0] == '[';
|
||||
/* Construct the sockaddress structure */
|
||||
if (!ipv6) {
|
||||
struct mscclpp::netIf ni;
|
||||
struct netIf ni;
|
||||
// parse <ip_or_hostname>:<port> string, expect one pair
|
||||
if (mscclpp::parseStringList(ip_port_pair, &ni, 1) != 1) {
|
||||
throw mscclpp::Error("Net : No valid <IPv4_or_hostname>:<port> pair found", mscclpp::ErrorCode::InvalidUsage);
|
||||
if (parseStringList(ip_port_pair, &ni, 1) != 1) {
|
||||
throw Error("Net : No valid <IPv4_or_hostname>:<port> pair found", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
struct addrinfo hints, *p;
|
||||
@@ -245,7 +245,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
|
||||
if ((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "Net : error encountered when getting address info : " << gai_strerror(rv);
|
||||
throw mscclpp::Error(ss.str(), mscclpp::ErrorCode::InvalidUsage);
|
||||
throw Error(ss.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
// use the first
|
||||
@@ -263,7 +263,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = 0; // should be global scope, set to 0
|
||||
} else {
|
||||
throw mscclpp::Error("Net : unsupported IP family", mscclpp::ErrorCode::InvalidUsage);
|
||||
throw Error("Net : unsupported IP family", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
freeaddrinfo(p); // all done with this structure
|
||||
@@ -276,7 +276,7 @@ void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair)
|
||||
}
|
||||
if (i == len) {
|
||||
WARN("Net : No valid [IPv6]:port pair found");
|
||||
throw mscclpp::Error("Net : No valid [IPv6]:port pair found", mscclpp::ErrorCode::InvalidUsage);
|
||||
throw Error("Net : No valid [IPv6]:port pair found", ErrorCode::InvalidUsage);
|
||||
}
|
||||
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
|
||||
|
||||
@@ -448,7 +448,7 @@ void Socket::bindAndListen() {
|
||||
}
|
||||
|
||||
void Socket::connect(int64_t timeout) {
|
||||
mscclpp::Timer timer;
|
||||
Timer timer;
|
||||
#ifdef MSCCLPP_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
@@ -483,7 +483,7 @@ void Socket::connect(int64_t timeout) {
|
||||
}
|
||||
|
||||
void Socket::accept(const Socket* listenSocket, int64_t timeout) {
|
||||
mscclpp::Timer timer;
|
||||
Timer timer;
|
||||
|
||||
if (listenSocket == NULL) {
|
||||
throw Error("listenSocket is NULL", ErrorCode::InvalidUsage);
|
||||
|
||||
@@ -99,8 +99,8 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
|
||||
EndpointConfig localConfig) {
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
|
||||
int remoteRank, int tag) {
|
||||
auto localEndpoint = context()->createEndpoint(localConfig);
|
||||
|
||||
if (remoteRank == bootstrap()->getRank()) {
|
||||
@@ -134,6 +134,33 @@ MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::co
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
// Overload that takes its arguments in (remoteRank, tag, localConfig) order.
// It only forwards to the connect(localConfig, remoteRank, tag) overload and returns that overload's future.
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
|
||||
EndpointConfig localConfig) {
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
}
|
||||
|
||||
// Exchanges semaphore stubs with `remoteRank` through the bootstrap and returns a future for the Semaphore
// built from the local and the remote stub.
//
// The serialized local stub is sent immediately. The matching receive is wrapped in a deferred task
// (std::launch::deferred), so it only runs once the returned future (or a later receive item that waits on
// it) is waited on.
//
// connection  connection the local SemaphoreStub is constructed from
// remoteRank  rank of the peer the stubs are exchanged with
// tag         tag passed to the bootstrap send/recv calls and used to look up the previous receive item
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::shared_ptr<Connection> connection,
                                                                           int remoteRank, int tag) {
  SemaphoreStub myStub(connection);
  bootstrap()->send(myStub.serialize(), remoteRank, tag);

  // Grab the receive item currently registered for this (rank, tag); it is replaced further down.
  auto earlierRecv = pimpl_->getLastRecvItem(remoteRank, tag);

  auto receiveTask = [this, remoteRank, tag, prevRecv = std::move(earlierRecv), myStub]() mutable {
    // Wait for the previously registered receive item (if any) before issuing our own recv
    if (prevRecv) {
      prevRecv->wait();
    }
    std::vector<char> payload;
    bootstrap()->recv(payload, remoteRank, tag);
    auto peerStub = SemaphoreStub::deserialize(payload);
    return Semaphore(myStub, peerStub);
  };

  std::shared_future<Semaphore> result = std::async(std::launch::deferred, std::move(receiveTask)).share();
  // Register this receive so that the next receive on the same (rank, tag) waits on it first
  pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Semaphore>>(result));
  return result;
}
|
||||
|
||||
// Returns the remote rank recorded for `connection`, looked up by the connection's address in
// pimpl_->connectionInfos_.
// NOTE(review): .at() presumably throws when the connection is unknown - confirm the container type.
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
|
||||
return pimpl_->connectionInfos_.at(&connection).remoteRank;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "api.h"
|
||||
#include "debug.h"
|
||||
#include "endpoint.hpp"
|
||||
|
||||
@@ -32,32 +33,21 @@ std::shared_ptr<RegisteredMemory::Impl> Connection::getImpl(RegisteredMemory& me
|
||||
|
||||
// Returns the implementation pointer (pimpl_) held by the given Endpoint.
// NOTE(review): the parameter is named `memory` although it is an Endpoint - presumably carried over from the
// RegisteredMemory overload; confirm before renaming.
std::shared_ptr<Endpoint::Impl> Connection::getImpl(Endpoint& memory) { return memory.pimpl_; }
|
||||
|
||||
std::string Connection::getTransportName() const {
|
||||
return TransportNames[static_cast<int>(this->transport())] + " -> " +
|
||||
TransportNames[static_cast<int>(this->remoteTransport())];
|
||||
}
|
||||
// Returns the device reported by this connection's local endpoint (localEndpoint_.device()).
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return localEndpoint_.device(); }
|
||||
|
||||
int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
|
||||
// Returns the maximum write-queue size stored for this connection (maxWriteQueueSize_).
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint,
|
||||
std::shared_ptr<CudaIpcStream> stream)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize()), stream_(stream) {
|
||||
if (localEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
|
||||
: Connection(context, localEndpoint), stream_(stream) {
|
||||
if (localEndpoint.transport() != Transport::CudaIpc || remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw Error("CudaIpc transport is required for CudaIpcConnection", ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw mscclpp::Error("Cuda IPC connection can only be made to a Cuda IPC endpoint", ErrorCode::InvalidUsage);
|
||||
if (localEndpoint.device().type != DeviceType::GPU && remoteEndpoint.device().type != DeviceType::GPU) {
|
||||
throw Error("CudaIpcConnection requires at least one GPU endpoint", ErrorCode::InvalidUsage);
|
||||
}
|
||||
// sanity check: make sure the IPC connection is being made within a node
|
||||
if (getImpl(remoteEndpoint)->hostHash_ != getImpl(localEndpoint)->hostHash_) {
|
||||
std::stringstream ss;
|
||||
ss << "Cuda IPC connection can only be made within a node: " << std::hex << getImpl(remoteEndpoint)->hostHash_
|
||||
<< " != " << std::hex << getImpl(localEndpoint)->hostHash_;
|
||||
throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
INFO(MSCCLPP_P2P, "Cuda IPC connection created");
|
||||
}
|
||||
|
||||
Transport CudaIpcConnection::transport() const { return Transport::CudaIpc; }
|
||||
@@ -126,11 +116,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
// IBConnection
|
||||
|
||||
IBConnection::IBConnection(std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
|
||||
: EndpointConfig::DefaultMaxCqSize),
|
||||
: Connection(context, localEndpoint),
|
||||
transport_(localEndpoint.transport()),
|
||||
remoteTransport_(remoteEndpoint.transport()),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
if (maxWriteQueueSize_ == -1) {
|
||||
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
|
||||
}
|
||||
qp_ = getImpl(localEndpoint)->ibQp_;
|
||||
qp_->rtr(getImpl(remoteEndpoint)->ibQpInfo_);
|
||||
qp_->rts();
|
||||
@@ -240,13 +232,13 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
|
||||
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, Endpoint localEndpoint,
|
||||
Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize)
|
||||
: Connection(context, localEndpoint.maxWriteQueueSize()),
|
||||
: Connection(context, localEndpoint),
|
||||
abortFlag_(0),
|
||||
sendBufferSize_(sendBufferSize),
|
||||
recvBufferSize_(recvBufferSize) {
|
||||
// Validating Transport Protocol
|
||||
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
|
||||
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
|
||||
throw Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
// Instantiating Buffers
|
||||
|
||||
@@ -13,10 +13,14 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
CudaIpcStream::CudaIpcStream() : stream_(std::make_shared<CudaStreamWithFlags>()), dirty_(false) {}
|
||||
// Builds a CudaIpcStream for `deviceId`: allocates a default-constructed CudaStreamWithFlags, records the
// device id in deviceId_, and starts with dirty_ cleared.
CudaIpcStream::CudaIpcStream(int deviceId)
|
||||
: stream_(std::make_shared<CudaStreamWithFlags>()), deviceId_(deviceId), dirty_(false) {}
|
||||
|
||||
void CudaIpcStream::setStreamIfNeeded() {
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);
|
||||
if (!env()->cudaIpcUseDefaultStream && stream_->empty()) {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(deviceId_));
|
||||
stream_->set(cudaStreamNonBlocking);
|
||||
}
|
||||
}
|
||||
|
||||
void CudaIpcStream::memcpyD2D(void *dst, const void *src, size_t nbytes) {
|
||||
@@ -68,29 +72,40 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
|
||||
std::shared_ptr<Connection> conn;
|
||||
if (localEndpoint.transport() == Transport::CudaIpc) {
|
||||
if (remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
|
||||
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
int deviceId;
|
||||
if (localEndpoint.device().type == DeviceType::GPU) {
|
||||
deviceId = localEndpoint.device().id;
|
||||
} else if (remoteEndpoint.device().type == DeviceType::GPU) {
|
||||
deviceId = remoteEndpoint.device().id;
|
||||
} else {
|
||||
throw Error("CudaIpc transport requires at least one GPU device", ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (deviceId < 0) {
|
||||
throw Error("No GPU device ID provided", ErrorCode::InvalidUsage);
|
||||
}
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>(deviceId));
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
if (pimpl_->ipcStreams_.empty()) {
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>());
|
||||
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaIpcStream>(deviceId));
|
||||
}
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint,
|
||||
pimpl_->ipcStreams_.back());
|
||||
} else if (AllIBTransports.has(localEndpoint.transport())) {
|
||||
if (!AllIBTransports.has(remoteEndpoint.transport())) {
|
||||
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
|
||||
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else if (localEndpoint.transport() == Transport::Ethernet) {
|
||||
if (remoteEndpoint.transport() != Transport::Ethernet) {
|
||||
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
|
||||
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else {
|
||||
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
|
||||
throw Error("Unsupported transport", ErrorCode::InternalError);
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
12
src/core.cc
12
src/core.cc
@@ -93,6 +93,18 @@ const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Tran
|
||||
|
||||
namespace std {
|
||||
|
||||
// Returns the short display name of `transport` ("IPC", "IB0", "ETH", ...).
//
// The enum value is used as an index into the name table below. A value that falls outside the table maps to
// "UNK" instead of reading past the end of the array.
std::string to_string(const mscclpp::Transport& transport) {
  static const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
                                               "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};
  constexpr size_t numNames = sizeof(TransportNames) / sizeof(TransportNames[0]);
  const size_t index = static_cast<size_t>(transport);
  // Guard the lookup: an out-of-range enum value would otherwise index beyond the table
  if (index >= numNames) {
    return TransportNames[0];
  }
  return TransportNames[index];
}
|
||||
|
||||
// Formats `device` as "Device(type=<type>, id=<id>)", using to_string() for the device type.
std::string to_string(const mscclpp::Device& device) {
  std::ostringstream out;
  out << "Device(type=";
  out << to_string(device.type);
  out << ", id=";
  out << device.id;
  out << ")";
  return out.str();
}
|
||||
|
||||
template <>
|
||||
struct hash<mscclpp::TransportFlags> {
|
||||
size_t operator()(const mscclpp::TransportFlags& flags) const {
|
||||
|
||||
@@ -7,13 +7,20 @@
|
||||
|
||||
#include "api.h"
|
||||
#include "context.hpp"
|
||||
#include "serialization.hpp"
|
||||
#include "socket.h"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
: transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) {
|
||||
: transport_(config.transport),
|
||||
device_(config.device),
|
||||
hostHash_(getHostHash()),
|
||||
maxWriteQueueSize_(config.maxWriteQueueSize) {
|
||||
if (device_.type == DeviceType::GPU && device_.id < 0) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&(device_.id)));
|
||||
}
|
||||
if (AllIBTransports.has(transport_)) {
|
||||
ibLocal_ = true;
|
||||
ibQp_ = contextImpl.getIbContext(transport_)
|
||||
@@ -23,7 +30,7 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
// Configuring Ethernet Interfaces
|
||||
abortFlag_ = 0;
|
||||
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);
|
||||
if (ret <= 0) throw Error("Failed to find network interfaces", ErrorCode::InternalError);
|
||||
|
||||
// Starting Server Socket
|
||||
socket_ = std::make_unique<Socket>(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_);
|
||||
@@ -32,20 +39,38 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; }
|
||||
// Rebuilds the endpoint state from a serialized byte buffer.
//
// Fields are read back in a fixed order: transport, device, host hash, then the IB queue-pair info (only for
// IB transports) and the socket address (only for Ethernet). maxWriteQueueSize_ is not read from the buffer.
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
  auto cursor = serialization.begin();

  // Fields present for every transport
  cursor = detail::deserialize(cursor, transport_);
  cursor = detail::deserialize(cursor, device_);
  cursor = detail::deserialize(cursor, hostHash_);

  const bool overIb = AllIBTransports.has(transport_);
  if (overIb) {
    ibLocal_ = false;  // the config-based constructor sets this flag to true instead
    cursor = detail::deserialize(cursor, ibQpInfo_);
  }

  const bool overEthernet = (transport_ == Transport::Ethernet);
  if (overEthernet) {
    cursor = detail::deserialize(cursor, socketAddress_);
  }
}
|
||||
|
||||
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() { return pimpl_->maxWriteQueueSize_; }
|
||||
// Wraps an existing implementation object; the Endpoint keeps a shared_ptr copy of `pimpl`.
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
|
||||
// Returns the transport stored in the implementation (pimpl_->transport_).
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->transport_; }
|
||||
|
||||
// Returns a reference to the device stored in the implementation (pimpl_->device_).
MSCCLPP_API_CPP const Device& Endpoint::device() const { return pimpl_->device_; }
|
||||
|
||||
// Returns the maximum write-queue size stored in the implementation (pimpl_->maxWriteQueueSize_).
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() const { return pimpl_->maxWriteQueueSize_; }
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() const {
|
||||
std::vector<char> data;
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash_), sizeof(pimpl_->hostHash_), std::back_inserter(data));
|
||||
detail::serialize(data, pimpl_->transport_);
|
||||
detail::serialize(data, pimpl_->device_);
|
||||
detail::serialize(data, pimpl_->hostHash_);
|
||||
if (AllIBTransports.has(pimpl_->transport_)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
|
||||
detail::serialize(data, pimpl_->ibQpInfo_);
|
||||
}
|
||||
if ((pimpl_->transport_) == Transport::Ethernet) {
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
|
||||
std::back_inserter(data));
|
||||
detail::serialize(data, pimpl_->socketAddress_);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
@@ -54,23 +79,4 @@ MSCCLPP_API_CPP Endpoint Endpoint::deserialize(const std::vector<char>& data) {
|
||||
return Endpoint(std::make_shared<Impl>(data));
|
||||
}
|
||||
|
||||
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
|
||||
auto it = serialization.begin();
|
||||
std::copy_n(it, sizeof(transport_), reinterpret_cast<char*>(&transport_));
|
||||
it += sizeof(transport_);
|
||||
std::copy_n(it, sizeof(hostHash_), reinterpret_cast<char*>(&hostHash_));
|
||||
it += sizeof(hostHash_);
|
||||
if (AllIBTransports.has(transport_)) {
|
||||
ibLocal_ = false;
|
||||
std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
|
||||
it += sizeof(ibQpInfo_);
|
||||
}
|
||||
if (transport_ == Transport::Ethernet) {
|
||||
std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
|
||||
it += sizeof(socketAddress_);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -185,7 +185,7 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
|
||||
else if (this->outputChunks.at(rank) != 0)
|
||||
sizePerRank = outputSize / this->outputChunks.at(rank);
|
||||
else
|
||||
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
|
||||
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
|
||||
|
||||
if (this->isUsingPacket) {
|
||||
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
|
||||
@@ -203,7 +203,7 @@ size_t ExecutionPlan::Impl::getMaxScratchBufferSize(int rank) const {
|
||||
else if (this->outputChunks.at(rank) != 0)
|
||||
sizePerChunk = maxMessageSize / this->outputChunks.at(rank);
|
||||
else
|
||||
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
|
||||
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
|
||||
|
||||
return this->getScratchBufferSize(rank, sizePerChunk * this->inputChunks.at(rank),
|
||||
sizePerChunk * this->outputChunks.at(rank));
|
||||
@@ -414,12 +414,12 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
|
||||
for (const auto& op : threadblock["ops"]) {
|
||||
Operation operation = {};
|
||||
std::vector<uint32_t> chunkIndexes;
|
||||
operation.type = static_cast<mscclpp::OperationType>(getOpType(op["name"]));
|
||||
operation.type = static_cast<OperationType>(getOpType(op["name"]));
|
||||
if (op.contains("ctype")) {
|
||||
operation.channelType = convertToChannelType(op["ctype"]);
|
||||
}
|
||||
if (op.contains("i_cids")) {
|
||||
if (operation.channelType == mscclpp::ChannelType::NVLS) {
|
||||
if (operation.channelType == ChannelType::NVLS) {
|
||||
BufferType srcBufferType = convertToBufferType(op["srcbuff"]);
|
||||
operation.nvlsInputIndex =
|
||||
channelIndexes[{srcBufferType, srcBufferType, ChannelType::NVLS}][op["i_cids"][0]["id"]];
|
||||
@@ -453,7 +453,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
|
||||
if (op.contains("o_cids")) {
|
||||
operation.nOutputs = op["o_cids"].size();
|
||||
for (int i = 0; i < operation.nOutputs; i++) {
|
||||
if (operation.channelType == mscclpp::ChannelType::NVLS) {
|
||||
if (operation.channelType == ChannelType::NVLS) {
|
||||
BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
|
||||
operation.nvlsOutputIndex =
|
||||
channelIndexes[{dstBufferType, dstBufferType, ChannelType::NVLS}][op["o_cids"][0]["id"]];
|
||||
@@ -516,14 +516,13 @@ std::pair<size_t, uint32_t> ExecutionPlan::Impl::getSizeAndChunksForRank(int ran
|
||||
size_t outputSize) const {
|
||||
std::pair<size_t, uint32_t> sizePerRank;
|
||||
if (this->inputChunks.at(rank) == 0 && this->outputChunks.at(rank) == 0) {
|
||||
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
|
||||
throw Error("Output or Input chunks must be greater than 0", ErrorCode::ExecutorError);
|
||||
} else if (this->inputChunks.at(rank) != 0 && this->outputChunks.at(rank) != 0) {
|
||||
if (inputSize / this->inputChunks.at(rank) != outputSize / this->outputChunks.at(rank))
|
||||
throw mscclpp::Error("Size per chunks inconsistent: inputSize " + std::to_string(inputSize) + " inputChunks " +
|
||||
std::to_string(this->inputChunks.at(rank)) + " outputSize " +
|
||||
std::to_string(outputSize) + " outputChunks " +
|
||||
std::to_string(this->outputChunks.at(rank)),
|
||||
mscclpp::ErrorCode::ExecutorError);
|
||||
throw Error("Size per chunks inconsistent: inputSize " + std::to_string(inputSize) + " inputChunks " +
|
||||
std::to_string(this->inputChunks.at(rank)) + " outputSize " + std::to_string(outputSize) +
|
||||
" outputChunks " + std::to_string(this->outputChunks.at(rank)),
|
||||
ErrorCode::ExecutorError);
|
||||
else
|
||||
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
|
||||
} else if (this->inputChunks.at(rank) != 0) {
|
||||
|
||||
@@ -216,7 +216,7 @@ struct Executor::Impl {
|
||||
for (int peer : connectedPeers) {
|
||||
Transport transport =
|
||||
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
|
||||
connectionFutures.push_back(this->comm->connect(peer, 0, transport));
|
||||
connectionFutures.push_back(this->comm->connect(transport, peer));
|
||||
}
|
||||
for (size_t i = 0; i < connectionFutures.size(); i++) {
|
||||
context.connections[connectedPeers[i]] = connectionFutures[i].get();
|
||||
@@ -263,12 +263,12 @@ struct Executor::Impl {
|
||||
std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
comm->sendMemory(memory, peer, 0);
|
||||
comm->sendMemory(memory, peer);
|
||||
}
|
||||
channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
|
||||
connectedPeers = getConnectedPeers(channelInfos);
|
||||
for (int peer : connectedPeers) {
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(peer, 0));
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(peer));
|
||||
}
|
||||
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
|
||||
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
|
||||
|
||||
@@ -242,8 +242,8 @@ bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
|
||||
return false;
|
||||
}
|
||||
MSCCLPP_CUTHROW(cuMemRelease(handle));
|
||||
if (!mscclpp::isNvlsSupported()) {
|
||||
throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
|
||||
if (!isNvlsSupported()) {
|
||||
throw Error("cuMemMap is used in env without NVLS support", ErrorCode::InvalidUsage);
|
||||
}
|
||||
return true;
|
||||
#endif
|
||||
|
||||
41
src/ib.cc
41
src/ib.cc
@@ -52,7 +52,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
|
||||
|
||||
CUdeviceptr dptr = reinterpret_cast<CUdeviceptr>(buff);
|
||||
bool cuMemAlloc = mscclpp::isCuMemMapAllocated((void*)dptr);
|
||||
bool cuMemAlloc = isCuMemMapAllocated((void*)dptr);
|
||||
int dmaBufSupported = 0;
|
||||
#if !defined(__HIP_PLATFORM_AMD__)
|
||||
CUdevice dev;
|
||||
@@ -71,11 +71,10 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
if (this->mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_dmabuf_mr failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
#else
|
||||
throw mscclpp::Error("Registeration of dma-buf based memory region failed on HIP platform",
|
||||
ErrorCode::InvalidUsage);
|
||||
throw Error("Registeration of dma-buf based memory region failed on HIP platform", ErrorCode::InvalidUsage);
|
||||
#endif // !defined(__HIP_PLATFORM_AMD__)
|
||||
} else {
|
||||
this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast<void*>(addr), pages * pageSize,
|
||||
@@ -84,7 +83,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
if (this->mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_mr failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +110,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
|
||||
if (this->cq == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_cq failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
|
||||
struct ibv_qp_init_attr qpInitAttr;
|
||||
@@ -130,14 +129,14 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
|
||||
if (_qp == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_create_qp failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
|
||||
struct ibv_port_attr portAttr;
|
||||
if (IBVerbs::ibv_query_port_w(ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
this->info.lid = portAttr.lid;
|
||||
this->info.port = port;
|
||||
@@ -151,7 +150,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
|
||||
if (IBVerbs::ibv_query_gid(ctx, port, 0, &gid) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_gid failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
this->info.spn = gid.global.subnet_prefix;
|
||||
this->info.iid = gid.global.interface_id;
|
||||
@@ -166,7 +165,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
|
||||
if (IBVerbs::ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
this->qp = _qp;
|
||||
this->wrn = 0;
|
||||
@@ -210,7 +209,7 @@ void IbQp::rtr(const IbQpInfo& info) {
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +228,7 @@ void IbQp::rts() {
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_modify_qp failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,7 +236,7 @@ IbQp::WrInfo IbQp::getNewWrInfo() {
|
||||
if (this->wrn >= this->maxWrPerSend) {
|
||||
std::stringstream err;
|
||||
err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
|
||||
throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
|
||||
throw Error(err.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
int wrn = this->wrn;
|
||||
|
||||
@@ -306,7 +305,7 @@ void IbQp::postSend() {
|
||||
if (ret != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_post_send failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
this->wrn = 0;
|
||||
this->numSignaledPostedItems += this->numSignaledStagedItems;
|
||||
@@ -332,7 +331,7 @@ int IbQp::getNumCqItems() const { return this->numSignaledPostedItems; }
|
||||
IbCtx::IbCtx(const std::string& devName) : devName(devName) {
|
||||
#if !defined(__HIP_PLATFORM_AMD__)
|
||||
if (!checkNvPeerMemLoaded()) {
|
||||
throw mscclpp::Error("nvidia_peermem kernel module is not loaded", ErrorCode::InternalError);
|
||||
throw Error("nvidia_peermem kernel module is not loaded", ErrorCode::InternalError);
|
||||
}
|
||||
#endif // !defined(__HIP_PLATFORM_AMD__)
|
||||
int num;
|
||||
@@ -347,13 +346,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
|
||||
if (this->ctx == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
this->pd = IBVerbs::ibv_alloc_pd(this->ctx);
|
||||
if (this->pd == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_alloc_pd failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,7 +372,7 @@ bool IbCtx::isPortUsable(int port) const {
|
||||
if (IBVerbs::ibv_query_port_w(this->ctx, port, &portAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
return portAttr.state == IBV_PORT_ACTIVE &&
|
||||
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND);
|
||||
@@ -384,7 +383,7 @@ int IbCtx::getAnyActivePort() const {
|
||||
if (IBVerbs::ibv_query_device(this->ctx, &devAttr) != 0) {
|
||||
std::stringstream err;
|
||||
err << "ibv_query_device failed (errno " << errno << ")";
|
||||
throw mscclpp::IbError(err.str(), errno);
|
||||
throw IbError(err.str(), errno);
|
||||
}
|
||||
for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
|
||||
if (this->isPortUsable(port)) {
|
||||
@@ -399,10 +398,10 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw mscclpp::Error("No active port found", ErrorCode::InvalidUsage);
|
||||
throw Error("No active port found", ErrorCode::InvalidUsage);
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
|
||||
throw Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
|
||||
return qps.back().get();
|
||||
|
||||
@@ -16,12 +16,13 @@ namespace mscclpp {
|
||||
class CudaIpcStream {
|
||||
private:
|
||||
std::shared_ptr<CudaStreamWithFlags> stream_;
|
||||
int deviceId_;
|
||||
bool dirty_;
|
||||
|
||||
void setStreamIfNeeded();
|
||||
|
||||
public:
|
||||
CudaIpcStream();
|
||||
CudaIpcStream(int deviceId);
|
||||
|
||||
void memcpyD2D(void *dst, const void *src, size_t nbytes);
|
||||
|
||||
@@ -30,6 +31,8 @@ class CudaIpcStream {
|
||||
void sync();
|
||||
|
||||
operator cudaStream_t() const { return *stream_; }
|
||||
|
||||
int deviceId() const { return deviceId_; }
|
||||
};
|
||||
|
||||
struct Context::Impl {
|
||||
|
||||
@@ -19,6 +19,7 @@ struct Endpoint::Impl {
|
||||
Impl(const std::vector<char>& serialization);
|
||||
|
||||
Transport transport_;
|
||||
Device device_;
|
||||
uint64_t hostHash_;
|
||||
int maxWriteQueueSize_;
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ struct RegisteredMemory::Impl {
|
||||
int fileDesc = -1;
|
||||
|
||||
Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl);
|
||||
Impl(const std::vector<char>::const_iterator& begin, const std::vector<char>::const_iterator& end);
|
||||
/// Constructs a RegisteredMemory::Impl from a vector of data. The constructor should only be used for the remote
|
||||
/// memory.
|
||||
Impl(const std::vector<char>& data);
|
||||
|
||||
26
src/include/serialization.hpp
Normal file
26
src/include/serialization.hpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef MSCCLPP_SERIALIZATION_HPP_
|
||||
#define MSCCLPP_SERIALIZATION_HPP_
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
namespace mscclpp::detail {
|
||||
|
||||
/// Appends the raw object representation of `value` (exactly `sizeof(T)` bytes) to the end of `buffer`.
///
/// The bytes are copied as-is, so the result is only meaningful for types whose object representation
/// can be copied byte-wise; nothing here enforces that.
///
/// @tparam T Type of the value to append.
/// @param buffer Destination byte buffer; existing content is preserved.
/// @param value Value whose bytes are appended.
template <typename T>
void serialize(std::vector<char>& buffer, const T& value) {
  const char* data = reinterpret_cast<const char*>(&value);
  // A single range insert grows the buffer once and needs nothing beyond <vector>
  // (std::back_inserter would additionally require <iterator>).
  buffer.insert(buffer.end(), data, data + sizeof(T));
}
|
||||
|
||||
/// Reads `sizeof(T)` bytes starting at `pos` into the storage of `value`.
///
/// No bounds checking is performed: the caller must guarantee that at least `sizeof(T)` bytes are
/// available starting at `pos`.
///
/// @tparam T Type of the value to fill.
/// @param pos Iterator to the first byte to read.
/// @param value Object whose bytes are overwritten.
/// @return Iterator just past the last byte that was read.
template <typename T>
std::vector<char>::const_iterator deserialize(const std::vector<char>::const_iterator& pos, T& value) {
  char* const destination = reinterpret_cast<char*>(&value);
  const std::vector<char>::const_iterator next = pos + sizeof(T);
  std::copy(pos, next, destination);
  return next;
}
|
||||
|
||||
} // namespace mscclpp::detail
|
||||
|
||||
#endif // MSCCLPP_SERIALIZATION_HPP_
|
||||
@@ -11,6 +11,9 @@ namespace mscclpp {
|
||||
/// Creates a channel that shares ownership of an existing device-to-device semaphore.
///
/// @param semaphore Semaphore stored by the channel.
MSCCLPP_API_CPP BaseMemoryChannel::BaseMemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore)
    : semaphore_(semaphore) {
}
|
||||
|
||||
/// Creates a channel from a generic Semaphore by wrapping it in a newly allocated
/// MemoryDevice2DeviceSemaphore.
///
/// @param semaphore Semaphore used to construct the MemoryDevice2DeviceSemaphore.
MSCCLPP_API_CPP BaseMemoryChannel::BaseMemoryChannel(const Semaphore& semaphore)
    : semaphore_(std::make_shared<MemoryDevice2DeviceSemaphore>(semaphore)) {}
|
||||
|
||||
MSCCLPP_API_CPP MemoryChannel::MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, void* src, void* packetBuffer)
|
||||
: BaseMemoryChannel(semaphore), dst_(dst), src_(src), packetBuffer_(packetBuffer) {
|
||||
@@ -19,6 +22,10 @@ MSCCLPP_API_CPP MemoryChannel::MemoryChannel(std::shared_ptr<MemoryDevice2Device
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a memory channel from a generic Semaphore.
///
/// The semaphore is wrapped in a newly allocated MemoryDevice2DeviceSemaphore and construction is
/// delegated to the constructor taking a shared semaphore pointer.
///
/// @param semaphore Semaphore used to construct the MemoryDevice2DeviceSemaphore.
/// @param dst Registered memory forwarded as the destination.
/// @param src Pointer forwarded as the source.
/// @param packetBuffer Pointer forwarded as the packet buffer.
MSCCLPP_API_CPP MemoryChannel::MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src,
                                             void* packetBuffer)
    : MemoryChannel(std::make_shared<MemoryDevice2DeviceSemaphore>(semaphore),
                    dst,
                    src,
                    packetBuffer) {
}
|
||||
|
||||
/// Builds the device-side handle of this channel from the device handle of its semaphore.
///
/// @return A DeviceHandle constructed from `semaphore_->deviceHandle()`.
MSCCLPP_API_CPP BaseMemoryChannel::DeviceHandle BaseMemoryChannel::deviceHandle() const {
  return DeviceHandle(semaphore_->deviceHandle());
}
|
||||
|
||||
@@ -14,10 +14,18 @@ MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId,
|
||||
std::shared_ptr<Proxy> proxy)
|
||||
: semaphoreId_(semaphoreId), semaphore_(semaphore), proxy_(proxy) {}
|
||||
|
||||
/// Creates a port channel from a generic Semaphore by wrapping it in a newly allocated
/// Host2DeviceSemaphore.
///
/// @param semaphoreId Identifier stored for the semaphore.
/// @param semaphore Semaphore used to construct the Host2DeviceSemaphore.
/// @param proxy Proxy stored by the channel.
MSCCLPP_API_CPP BasePortChannel::BasePortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore,
                                                 std::shared_ptr<Proxy> proxy)
    : semaphoreId_(semaphoreId),
      semaphore_(std::make_shared<Host2DeviceSemaphore>(semaphore)),
      proxy_(proxy) {}
|
||||
|
||||
/// Creates a port channel bound to an existing Host2DeviceSemaphore.
///
/// @param semaphoreId Identifier forwarded to BasePortChannel.
/// @param semaphore Semaphore forwarded to BasePortChannel.
/// @param proxy Proxy forwarded to BasePortChannel.
/// @param dst Identifier stored as the destination memory.
/// @param src Identifier stored as the source memory.
MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
                                         std::shared_ptr<Proxy> proxy, MemoryId dst, MemoryId src)
    : BasePortChannel(semaphoreId, semaphore, proxy),
      dst_(dst),
      src_(src) {
}
|
||||
|
||||
/// Creates a port channel from a generic Semaphore.
///
/// The semaphore is passed on to the BasePortChannel constructor that accepts a `const Semaphore&`.
///
/// @param semaphoreId Identifier forwarded to BasePortChannel.
/// @param semaphore Semaphore forwarded to BasePortChannel.
/// @param proxy Proxy forwarded to BasePortChannel.
/// @param dst Identifier stored as the destination memory.
/// @param src Identifier stored as the source memory.
MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, const Semaphore& semaphore,
                                         std::shared_ptr<Proxy> proxy, MemoryId dst, MemoryId src)
    : BasePortChannel(semaphoreId, semaphore, proxy),
      dst_(dst),
      src_(src) {
}
|
||||
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
@@ -39,6 +47,11 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& com
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
|
||||
/// Wraps `semaphore` in a newly allocated Host2DeviceSemaphore and appends it to the service.
///
/// @param semaphore Semaphore used to construct the Host2DeviceSemaphore.
/// @return Index of the appended entry within `semaphores_`.
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(const Semaphore& semaphore) {
  auto hostToDevice = std::make_shared<Host2DeviceSemaphore>(semaphore);
  semaphores_.push_back(std::move(hostToDevice));
  return semaphores_.size() - 1;
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
|
||||
semaphores_.push_back(semaphore);
|
||||
return semaphores_.size() - 1;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "api.h"
|
||||
#include "context.hpp"
|
||||
#include "debug.h"
|
||||
#include "serialization.hpp"
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
#define MSCCLPP_CULOG_WARN(cmd) \
|
||||
@@ -119,43 +120,37 @@ MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() const { return pim
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() const {
|
||||
std::vector<char> result;
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->originalDataPtr), sizeof(pimpl_->originalDataPtr),
|
||||
std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->baseDataSize), sizeof(pimpl_->baseDataSize), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->isCuMemMapAlloc), sizeof(pimpl_->isCuMemMapAlloc),
|
||||
std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result));
|
||||
detail::serialize(result, pimpl_->originalDataPtr);
|
||||
detail::serialize(result, pimpl_->size);
|
||||
detail::serialize(result, pimpl_->baseDataSize);
|
||||
detail::serialize(result, pimpl_->hostHash);
|
||||
detail::serialize(result, pimpl_->pidHash);
|
||||
detail::serialize(result, pimpl_->isCuMemMapAlloc);
|
||||
detail::serialize(result, pimpl_->transports);
|
||||
if (pimpl_->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
|
||||
throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError);
|
||||
throw Error("Too many transport info entries", ErrorCode::InternalError);
|
||||
}
|
||||
int8_t transportCount = pimpl_->transportInfos.size();
|
||||
std::copy_n(reinterpret_cast<char*>(&transportCount), sizeof(transportCount), std::back_inserter(result));
|
||||
detail::serialize(result, transportCount);
|
||||
for (auto& entry : pimpl_->transportInfos) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
|
||||
detail::serialize(result, entry.transport);
|
||||
if (entry.transport == Transport::CudaIpc) {
|
||||
if (pimpl_->isCuMemMapAlloc) {
|
||||
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.shareableHandle), sizeof(entry.shareableHandle),
|
||||
std::back_inserter(result));
|
||||
detail::serialize(result, entry.shareableHandle);
|
||||
} else {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.rootPid), sizeof(entry.rootPid), std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.fileDesc), sizeof(entry.fileDesc), std::back_inserter(result));
|
||||
detail::serialize(result, entry.rootPid);
|
||||
detail::serialize(result, entry.fileDesc);
|
||||
}
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.offsetFromBase), sizeof(entry.offsetFromBase),
|
||||
std::back_inserter(result));
|
||||
detail::serialize(result, entry.offsetFromBase);
|
||||
} else {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle),
|
||||
std::back_inserter(result));
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase),
|
||||
std::back_inserter(result));
|
||||
detail::serialize(result, entry.cudaIpcBaseHandle);
|
||||
detail::serialize(result, entry.cudaIpcOffsetFromBase);
|
||||
}
|
||||
} else if (AllIBTransports.has(entry.transport)) {
|
||||
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
||||
detail::serialize(result, entry.ibMrInfo);
|
||||
} else {
|
||||
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
|
||||
throw Error("Unknown transport", ErrorCode::InternalError);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@@ -165,62 +160,44 @@ MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector
|
||||
return RegisteredMemory(std::make_shared<Impl>(data));
|
||||
}
|
||||
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
auto it = serialization.begin();
|
||||
std::copy_n(it, sizeof(this->originalDataPtr), reinterpret_cast<char*>(&this->originalDataPtr));
|
||||
it += sizeof(this->originalDataPtr);
|
||||
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
|
||||
it += sizeof(this->size);
|
||||
std::copy_n(it, sizeof(this->baseDataSize), reinterpret_cast<char*>(&this->baseDataSize));
|
||||
it += sizeof(this->baseDataSize);
|
||||
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
|
||||
it += sizeof(this->hostHash);
|
||||
std::copy_n(it, sizeof(this->pidHash), reinterpret_cast<char*>(&this->pidHash));
|
||||
it += sizeof(this->pidHash);
|
||||
std::copy_n(it, sizeof(this->isCuMemMapAlloc), reinterpret_cast<char*>(&this->isCuMemMapAlloc));
|
||||
it += sizeof(this->isCuMemMapAlloc);
|
||||
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
|
||||
it += sizeof(this->transports);
|
||||
RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
|
||||
const std::vector<char>::const_iterator& end) {
|
||||
auto it = begin;
|
||||
it = detail::deserialize(it, this->originalDataPtr);
|
||||
it = detail::deserialize(it, this->size);
|
||||
it = detail::deserialize(it, this->baseDataSize);
|
||||
it = detail::deserialize(it, this->hostHash);
|
||||
it = detail::deserialize(it, this->pidHash);
|
||||
it = detail::deserialize(it, this->isCuMemMapAlloc);
|
||||
it = detail::deserialize(it, this->transports);
|
||||
int8_t transportCount;
|
||||
std::copy_n(it, sizeof(transportCount), reinterpret_cast<char*>(&transportCount));
|
||||
it += sizeof(transportCount);
|
||||
it = detail::deserialize(it, transportCount);
|
||||
for (int i = 0; i < transportCount; ++i) {
|
||||
TransportInfo transportInfo;
|
||||
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
|
||||
it += sizeof(transportInfo.transport);
|
||||
it = detail::deserialize(it, transportInfo.transport);
|
||||
if (transportInfo.transport == Transport::CudaIpc) {
|
||||
if (this->isCuMemMapAlloc) {
|
||||
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
|
||||
std::copy_n(it, sizeof(transportInfo.shareableHandle),
|
||||
reinterpret_cast<char*>(&transportInfo.shareableHandle));
|
||||
it += sizeof(transportInfo.shareableHandle);
|
||||
it = detail::deserialize(it, transportInfo.shareableHandle);
|
||||
} else {
|
||||
std::copy_n(it, sizeof(transportInfo.rootPid), reinterpret_cast<char*>(&transportInfo.rootPid));
|
||||
it += sizeof(transportInfo.rootPid);
|
||||
std::copy_n(it, sizeof(transportInfo.fileDesc), reinterpret_cast<char*>(&transportInfo.fileDesc));
|
||||
it += sizeof(transportInfo.fileDesc);
|
||||
it = detail::deserialize(it, transportInfo.rootPid);
|
||||
it = detail::deserialize(it, transportInfo.fileDesc);
|
||||
}
|
||||
std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast<char*>(&transportInfo.offsetFromBase));
|
||||
it += sizeof(transportInfo.offsetFromBase);
|
||||
it = detail::deserialize(it, transportInfo.offsetFromBase);
|
||||
} else {
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
|
||||
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
|
||||
it += sizeof(transportInfo.cudaIpcBaseHandle);
|
||||
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
|
||||
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
|
||||
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
|
||||
it = detail::deserialize(it, transportInfo.cudaIpcBaseHandle);
|
||||
it = detail::deserialize(it, transportInfo.cudaIpcOffsetFromBase);
|
||||
}
|
||||
} else if (AllIBTransports.has(transportInfo.transport)) {
|
||||
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
|
||||
it += sizeof(transportInfo.ibMrInfo);
|
||||
it = detail::deserialize(it, transportInfo.ibMrInfo);
|
||||
transportInfo.ibLocal = false;
|
||||
} else {
|
||||
throw mscclpp::Error("Unknown transport", ErrorCode::InternalError);
|
||||
throw Error("Unknown transport", ErrorCode::InternalError);
|
||||
}
|
||||
this->transportInfos.push_back(transportInfo);
|
||||
}
|
||||
if (it != serialization.end()) {
|
||||
throw mscclpp::Error("Serialization failed", ErrorCode::InternalError);
|
||||
if (it != end) {
|
||||
throw Error("Serialization failed", ErrorCode::InternalError);
|
||||
}
|
||||
|
||||
// Next decide how to set this->data
|
||||
@@ -239,11 +216,11 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
} else {
|
||||
int rootPidFd = syscall(SYS_pidfd_open, entry.rootPid, 0);
|
||||
if (rootPidFd < 0) {
|
||||
throw mscclpp::SysError("pidfd_open() failed", errno);
|
||||
throw SysError("pidfd_open() failed", errno);
|
||||
}
|
||||
int fd = syscall(SYS_pidfd_getfd, rootPidFd, entry.fileDesc, 0);
|
||||
if (fd < 0) {
|
||||
throw mscclpp::SysError("pidfd_getfd() failed", errno);
|
||||
throw SysError("pidfd_getfd() failed", errno);
|
||||
}
|
||||
INFO(MSCCLPP_P2P, "Get file descriptor %d from pidfd %d on peer 0x%lx", fd, rootPidFd, hostHash);
|
||||
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast<void*>(fd),
|
||||
@@ -260,9 +237,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
detail::setReadWriteMemoryAccess(base, size);
|
||||
this->data = static_cast<char*>(base) + entry.offsetFromBase;
|
||||
#else
|
||||
throw mscclpp::Error(
|
||||
"CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.",
|
||||
mscclpp::ErrorCode::InvalidUsage);
|
||||
throw Error("CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.",
|
||||
ErrorCode::InvalidUsage);
|
||||
#endif
|
||||
} else {
|
||||
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
@@ -275,6 +251,9 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserializes from a whole byte vector by delegating to the iterator-range constructor.
///
/// @param serialization Complete serialized representation.
RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
    : Impl(serialization.cbegin(), serialization.cend()) {
}
|
||||
|
||||
RegisteredMemory::Impl::~Impl() {
|
||||
// Close the CUDA IPC handle if it was opened during deserialization
|
||||
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
|
||||
|
||||
209
src/semaphore.cc
209
src/semaphore.cc
@@ -5,116 +5,193 @@
|
||||
|
||||
#include "api.h"
|
||||
#include "atomic.hpp"
|
||||
#include "context.hpp"
|
||||
#include "debug.h"
|
||||
#include "registered_memory.hpp"
|
||||
#include "serialization.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
static std::shared_future<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
|
||||
void* localInboundSemaphoreId) {
|
||||
auto localInboundSemaphoreIdsRegMem =
|
||||
communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport());
|
||||
int remoteRank = communicator.remoteRankOf(*connection);
|
||||
int tag = communicator.tagOf(*connection);
|
||||
communicator.sendMemory(localInboundSemaphoreIdsRegMem, remoteRank, tag);
|
||||
return communicator.recvMemory(remoteRank, tag);
|
||||
/// Private state behind SemaphoreStub.
struct SemaphoreStub::Impl {
  /// Creates a local stub for `connection`: allocates the token and registers it as memory.
  Impl(std::shared_ptr<Connection> connection);

  /// Creates a stub from an already available ID memory and device description
  /// (the form used by SemaphoreStub::deserialize).
  Impl(const RegisteredMemory& idMemory, const Device& device);

  /// Declared for construction from serialized bytes.
  Impl(const std::vector<char>& data);

  /// Connection the stub was created for; left empty by the (idMemory, device) constructor.
  std::shared_ptr<Connection> connection_;
  /// Storage of the token; left empty by the (idMemory, device) constructor.
  std::shared_ptr<uint64_t> token_;
  /// Registered memory describing the token.
  RegisteredMemory idMemory_;
  /// Device description written after the ID memory by SemaphoreStub::serialize.
  Device device_;
};
|
||||
|
||||
static std::shared_ptr<uint64_t> gpuCallocToken() {
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
return detail::gpuCallocUncachedShared<uint64_t>();
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
return detail::gpuCallocShared<uint64_t>();
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
}
|
||||
|
||||
static detail::UniqueGpuPtr<uint64_t> createGpuSemaphoreId() {
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
return detail::gpuCallocUncachedUnique<uint64_t>();
|
||||
#else // !defined(__HIP_PLATFORM_AMD__)
|
||||
return detail::gpuCallocUnique<uint64_t>();
|
||||
#endif // !defined(__HIP_PLATFORM_AMD__)
|
||||
/// Creates the local half of a semaphore for `connection`.
///
/// A `uint64_t` token is allocated according to the connection's local device (host memory for CPU,
/// device memory via gpuCallocToken() for GPU) and registered with the connection's context for the
/// connection's transport.
///
/// @param connection Connection the stub belongs to; it is dereferenced and must not be null.
/// @throws Error (ErrorCode::InvalidUsage) if the local device is a GPU without a valid ID, or if the
///         local device type is neither CPU nor GPU.
SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection)
    // Record the local device as well: SemaphoreStub::serialize() writes device_, so it must not be
    // left uninitialized for locally created stubs.
    : connection_(connection), device_(connection->localDevice()) {
  // Allocate a semaphore ID on the local device
  if (device_.type == DeviceType::CPU) {
    token_ = std::make_shared<uint64_t>(0);
  } else if (device_.type == DeviceType::GPU) {
    if (device_.id < 0) {
      throw Error("Local GPU ID is not provided", ErrorCode::InvalidUsage);
    }
    // NOTE(review): this switches the calling thread's current device and does not restore it.
    MSCCLPP_CUDATHROW(cudaSetDevice(device_.id));
    token_ = gpuCallocToken();
  } else {
    throw Error("Unsupported local device type", ErrorCode::InvalidUsage);
  }
  // The call result is already a temporary, so no explicit std::move is needed.
  idMemory_ = connection_->context()->registerMemory(token_.get(), sizeof(uint64_t), connection_->transport());
}
|
||||
|
||||
/// Creates a stub from an existing ID memory and device description.
///
/// `connection_` and `token_` are left default-initialized (empty).
///
/// @param idMemory Registered memory describing the token.
/// @param device Device description stored alongside the memory.
SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device)
    : idMemory_(idMemory),
      device_(device) {
}
|
||||
|
||||
/// Wraps an already constructed implementation object; `pimpl` is moved into the stub.
SemaphoreStub::SemaphoreStub(std::shared_ptr<Impl> pimpl) : pimpl_(std::move(pimpl)) {}
|
||||
|
||||
/// Creates a local stub for `connection` by constructing a new Impl from it.
///
/// @param connection Connection the stub is created for.
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(std::shared_ptr<Connection> connection)
    : pimpl_(std::make_shared<Impl>(connection)) {
}
|
||||
|
||||
/// Serializes the stub as the serialized ID memory followed by the raw bytes of the device description.
///
/// @return Byte vector accepted by SemaphoreStub::deserialize.
MSCCLPP_API_CPP std::vector<char> SemaphoreStub::serialize() const {
  std::vector<char> bytes = pimpl_->idMemory_.serialize();
  detail::serialize(bytes, pimpl_->device_);
  return bytes;
}
|
||||
|
||||
/// Reconstructs a stub from bytes produced by SemaphoreStub::serialize.
///
/// The trailing `sizeof(Device)` bytes are read as the device description; everything before them is
/// handed to RegisteredMemory::Impl as the serialized ID memory.
///
/// @param data Serialized stub.
/// @return Stub holding the deserialized ID memory and device.
/// @throws Error (ErrorCode::InvalidUsage) if `data` is too short to contain a device description.
MSCCLPP_API_CPP SemaphoreStub SemaphoreStub::deserialize(const std::vector<char>& data) {
  Device device;
  // Guard the iterator arithmetic below: stepping back from end() by more than size() is undefined.
  if (data.size() < sizeof(device)) {
    throw Error("SemaphoreStub deserialize failed", ErrorCode::InvalidUsage);
  }
  auto memEnd = data.end() - static_cast<std::vector<char>::difference_type>(sizeof(device));
  RegisteredMemory idMemory(std::make_shared<RegisteredMemory::Impl>(data.begin(), memEnd));
  auto it = detail::deserialize(memEnd, device);
  // Defensive consistency check on the consumed length.
  if (it != data.end()) {
    throw Error("SemaphoreStub deserialize failed", ErrorCode::InvalidUsage);
  }
  return SemaphoreStub(std::make_shared<Impl>(std::move(idMemory), device));
}
|
||||
|
||||
/// Returns a reference to the registered ID memory held by this stub.
MSCCLPP_API_CPP const RegisteredMemory& SemaphoreStub::memory() const { return pimpl_->idMemory_; }
|
||||
|
||||
/// Private state behind Semaphore: a copy of the local stub plus the remote stub's registered memory.
struct Semaphore::Impl {
  /// @param localStub Stub copied into `localStub_`.
  /// @param remoteStubMemory Registered memory copied into `remoteStubMemory_`.
  Impl(const SemaphoreStub& localStub, const RegisteredMemory& remoteStubMemory)
      : localStub_(localStub),
        remoteStubMemory_(remoteStubMemory) {
  }

  /// Local stub; source of Semaphore::connection() and Semaphore::localMemory().
  SemaphoreStub localStub_;
  /// Memory of the remote stub; returned by Semaphore::remoteMemory().
  RegisteredMemory remoteStubMemory_;
};
|
||||
|
||||
/// Pairs a local stub with the registered memory of a remote stub.
///
/// @param localStub Stub of the local side; stored by copy.
/// @param remoteStub Stub of the remote side; only its memory() is stored.
Semaphore::Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub)
    : pimpl_(std::make_unique<Impl>(localStub, remoteStub.memory())) {
}
|
||||
|
||||
/// Returns the connection stored in the local stub.
///
/// @return The local stub's `connection_` pointer.
MSCCLPP_API_CPP std::shared_ptr<Connection> Semaphore::connection() const {
  const SemaphoreStub& local = pimpl_->localStub_;
  return local.pimpl_->connection_;
}
|
||||
|
||||
/// Returns the registered memory of the local stub.
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::localMemory() const { return pimpl_->localStub_.memory(); }
|
||||
|
||||
/// Returns the registered memory taken from the remote stub at construction.
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::remoteMemory() const { return pimpl_->remoteStubMemory_; }
|
||||
|
||||
/// Builds a Semaphore for `connection` through the communicator and waits for the result.
///
/// The remote rank and tag passed to Communicator::buildSemaphore are the ones the communicator
/// reports for the connection.
///
/// @param communicator Communicator used to build the semaphore.
/// @param connection Connection the semaphore is built for.
/// @return The semaphore obtained from the returned future's get().
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, std::shared_ptr<Connection> connection) {
  const int remoteRank = communicator.remoteRankOf(*connection);
  const int tag = communicator.tagOf(*connection);
  auto pendingSemaphore = communicator.buildSemaphore(connection, remoteRank, tag);
  return pendingSemaphore.get();
}
|
||||
|
||||
/// Creates a host-to-device semaphore on top of an existing Semaphore.
///
/// The expected-inbound token is allocated with detail::gpuCallocUnique and the outbound token with
/// std::make_unique.
///
/// @param semaphore Semaphore to build on; stored by copy.
/// @throws Error (ErrorCode::InvalidUsage) if the connection's local device is not a GPU.
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(const Semaphore& semaphore)
    : semaphore_(semaphore),
      expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
      outboundToken_(std::make_unique<uint64_t>()) {
  const bool localIsGpu = (connection()->localDevice().type == DeviceType::GPU);
  if (!localIsGpu) {
    throw Error("Local endpoint device type of Host2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
  }
}
|
||||
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), std::make_unique<uint64_t>()),
|
||||
connection_(connection) {
|
||||
INFO(MSCCLPP_INIT, "Creating a Host2Device semaphore for %s transport from %d to %d",
|
||||
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
|
||||
communicator.remoteRankOf(*connection));
|
||||
remoteInboundSemaphoreIdsRegMem_ =
|
||||
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
|
||||
}
|
||||
: Host2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() { return connection_; }
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() const { return semaphore_.connection(); }
|
||||
|
||||
MSCCLPP_API_CPP void Host2DeviceSemaphore::signal() {
|
||||
connection_->updateAndSync(remoteInboundSemaphoreIdsRegMem_.get(), 0, outboundSemaphore_.get(),
|
||||
*outboundSemaphore_ + 1);
|
||||
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() {
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() const {
|
||||
Host2DeviceSemaphore::DeviceHandle device;
|
||||
device.inboundSemaphoreId = localInboundSemaphore_.get();
|
||||
device.expectedInboundSemaphoreId = expectedInboundSemaphore_.get();
|
||||
device.inboundToken = reinterpret_cast<uint64_t*>(semaphore_.localMemory().data());
|
||||
device.expectedInboundToken = expectedInboundToken_.get();
|
||||
return device;
|
||||
}
|
||||
|
||||
/// Creates a host-to-host semaphore on top of an existing Semaphore.
///
/// Both the expected-inbound and the outbound tokens are allocated with std::make_unique.
///
/// @param semaphore Semaphore to build on; stored by copy.
/// @throws Error (ErrorCode::InvalidUsage) if the connection uses the CudaIpc transport, or if the
///         connection's local device is not a CPU.
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(const Semaphore& semaphore)
    : semaphore_(semaphore),
      expectedInboundToken_(std::make_unique<uint64_t>()),
      outboundToken_(std::make_unique<uint64_t>()) {
  const std::shared_ptr<Connection> conn = connection();
  // The transport is validated before the device type, so a CudaIpc connection reports that error first.
  if (conn->transport() == Transport::CudaIpc) {
    throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
  }
  if (conn->localDevice().type != DeviceType::CPU) {
    throw Error("Local endpoint device type of Host2HostSemaphore should be CPU", ErrorCode::InvalidUsage);
  }
}
|
||||
|
||||
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
: BaseSemaphore(std::make_unique<uint64_t>(), std::make_unique<uint64_t>(), std::make_unique<uint64_t>()),
|
||||
connection_(connection) {
|
||||
INFO(MSCCLPP_INIT, "Creating a Host2Host semaphore for %s transport from %d to %d",
|
||||
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
|
||||
communicator.remoteRankOf(*connection));
|
||||
: Host2HostSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
if (connection->transport() == Transport::CudaIpc) {
|
||||
throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
remoteInboundSemaphoreIdsRegMem_ =
|
||||
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() { return connection_; }
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() const { return semaphore_.connection(); }
|
||||
|
||||
MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
|
||||
connection_->updateAndSync(remoteInboundSemaphoreIdsRegMem_.get(), 0, outboundSemaphore_.get(),
|
||||
*outboundSemaphore_ + 1);
|
||||
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
|
||||
bool signaled =
|
||||
(atomicLoad((uint64_t*)localInboundSemaphore_.get(), memoryOrderAcquire) > (*expectedInboundSemaphore_));
|
||||
if (signaled) (*expectedInboundSemaphore_) += 1;
|
||||
bool signaled = (atomicLoad(reinterpret_cast<uint64_t*>(semaphore_.localMemory().data()), memoryOrderAcquire) >
|
||||
(*expectedInboundToken_));
|
||||
if (signaled) (*expectedInboundToken_) += 1;
|
||||
return signaled;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) {
|
||||
(*expectedInboundSemaphore_) += 1;
|
||||
(*expectedInboundToken_) += 1;
|
||||
int64_t spinCount = 0;
|
||||
while (atomicLoad((uint64_t*)localInboundSemaphore_.get(), memoryOrderAcquire) < (*expectedInboundSemaphore_)) {
|
||||
while (atomicLoad(reinterpret_cast<uint64_t*>(semaphore_.localMemory().data()), memoryOrderAcquire) <
|
||||
(*expectedInboundToken_)) {
|
||||
if (maxSpinCount >= 0 && spinCount++ == maxSpinCount) {
|
||||
throw Error("Host2HostSemaphore::wait timed out", ErrorCode::Timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const Semaphore& semaphore)
|
||||
: semaphore_(semaphore),
|
||||
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
|
||||
outboundToken_(detail::gpuCallocUnique<uint64_t>()) {
|
||||
if (connection()->localDevice().type != DeviceType::GPU) {
|
||||
throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
: BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), createGpuSemaphoreId()) {
|
||||
INFO(MSCCLPP_INIT, "Creating a Device2Device semaphore for %s transport from %d to %d",
|
||||
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
|
||||
communicator.remoteRankOf(*connection));
|
||||
if (connection->transport() == Transport::CudaIpc) {
|
||||
remoteInboundSemaphoreIdsRegMem_ =
|
||||
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
|
||||
isRemoteInboundSemaphoreIdSet_ = true;
|
||||
} else if (AllIBTransports.has(connection->transport())) {
|
||||
// Should we throw an error here?
|
||||
isRemoteInboundSemaphoreIdSet_ = false;
|
||||
}
|
||||
: MemoryDevice2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> MemoryDevice2DeviceSemaphore::connection() const {
|
||||
return semaphore_.connection();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSemaphore::deviceHandle() const {
|
||||
MemoryDevice2DeviceSemaphore::DeviceHandle device;
|
||||
device.remoteInboundSemaphoreId = isRemoteInboundSemaphoreIdSet_
|
||||
? reinterpret_cast<uint64_t*>(remoteInboundSemaphoreIdsRegMem_.get().data())
|
||||
: nullptr;
|
||||
device.inboundSemaphoreId = localInboundSemaphore_.get();
|
||||
device.expectedInboundSemaphoreId = expectedInboundSemaphore_.get();
|
||||
device.outboundSemaphoreId = outboundSemaphore_.get();
|
||||
device.remoteInboundToken = reinterpret_cast<uint64_t*>(semaphore_.remoteMemory().data());
|
||||
device.inboundToken = reinterpret_cast<uint64_t*>(semaphore_.localMemory().data());
|
||||
device.expectedInboundToken = expectedInboundToken_.get();
|
||||
device.outboundToken = outboundToken_.get();
|
||||
return device;
|
||||
};
|
||||
|
||||
|
||||
@@ -225,16 +225,17 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connections[r] = comm.connect(r, 0, transport);
|
||||
connections[r] = comm.connect(transport, r);
|
||||
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemory(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemory(r, 0));
|
||||
comm.sendMemory(memory, r);
|
||||
remoteMemories.push_back(comm.recvMemory(r));
|
||||
}
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
|
||||
auto sema = comm.buildSemaphore(connections[r].get(), r).get();
|
||||
semaphoreIds.push_back(proxyService.addSemaphore(sema));
|
||||
}
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
|
||||
|
||||
@@ -83,7 +83,6 @@ class MyProxyService {
|
||||
int dataSize_;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
|
||||
mscclpp::RegisteredMemory localMemory_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2HostSemaphore>> hostSemaphores_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores1_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores2_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
|
||||
@@ -105,7 +104,6 @@ class MyProxyService {
|
||||
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) {
|
||||
hostSemaphores_.emplace_back(nullptr);
|
||||
deviceSemaphores1_.emplace_back(nullptr);
|
||||
deviceSemaphores2_.emplace_back(nullptr);
|
||||
continue;
|
||||
@@ -117,10 +115,10 @@ class MyProxyService {
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connectionsFuture[r] = comm.connect(r, 0, transport);
|
||||
comm.sendMemory(localMemory_, r, 0);
|
||||
connectionsFuture[r] = comm.connect(transport, r);
|
||||
comm.sendMemory(localMemory_, r);
|
||||
|
||||
remoteMemoriesFuture[r] = comm.recvMemory(r, 0);
|
||||
remoteMemoriesFuture[r] = comm.recvMemory(r);
|
||||
}
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
@@ -128,11 +126,6 @@ class MyProxyService {
|
||||
continue;
|
||||
}
|
||||
connections_[r] = connectionsFuture[r].get();
|
||||
if (rankToNode(r) == thisNode) {
|
||||
hostSemaphores_.emplace_back(nullptr);
|
||||
} else {
|
||||
hostSemaphores_.emplace_back(std::make_shared<mscclpp::Host2HostSemaphore>(comm, connections_[r]));
|
||||
}
|
||||
deviceSemaphores1_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
|
||||
deviceSemaphores2_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
|
||||
remoteMemories_[r] = remoteMemoriesFuture[r].get();
|
||||
|
||||
@@ -47,11 +47,11 @@ void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet
|
||||
for (int i = 0; i < numRanksToUse; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
|
||||
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[i] = communicator->connect(mscclpp::Transport::CudaIpc, i);
|
||||
} else if (useIb) {
|
||||
connectionFutures[i] = communicator->connect(i, 0, ibTransport);
|
||||
connectionFutures[i] = communicator->connect(ibTransport, i);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::Ethernet);
|
||||
connectionFutures[i] = communicator->connect(mscclpp::Transport::Ethernet, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,28 +39,26 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
|
||||
continue;
|
||||
}
|
||||
if (rankToNode(r) == rankToNode(gEnv->rank)) {
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[r] = communicator->connect(mscclpp::Transport::CudaIpc, r);
|
||||
} else {
|
||||
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
|
||||
connectionFutures[r] = communicator->connect(ibTransport, r);
|
||||
}
|
||||
|
||||
if (isInPlace) {
|
||||
communicator->sendMemory(inputBufRegMem, r, 0);
|
||||
communicator->sendMemory(inputBufRegMem, r);
|
||||
} else {
|
||||
communicator->sendMemory(outputBufRegMem, r, 0);
|
||||
communicator->sendMemory(outputBufRegMem, r);
|
||||
}
|
||||
remoteMemFutures[r] = communicator->recvMemory(r, 0);
|
||||
remoteMemFutures[r] = communicator->recvMemory(r);
|
||||
}
|
||||
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
if (r == rank) {
|
||||
continue;
|
||||
}
|
||||
connections[r] = connectionFutures[r].get();
|
||||
auto sema = communicator->buildSemaphore(connectionFutures[r].get(), r).get();
|
||||
|
||||
memorySemaphores[r] = std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*communicator, connections[r]);
|
||||
|
||||
memoryChannels.emplace_back(memorySemaphores[r], remoteMemFutures[r].get(), inputBufRegMem.data(),
|
||||
memoryChannels.emplace_back(sema, remoteMemFutures[r].get(), inputBufRegMem.data(),
|
||||
(isInPlace ? nullptr : outputBufRegMem.data()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,26 +42,28 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
|
||||
continue;
|
||||
}
|
||||
if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) {
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[r] = communicator->connect(mscclpp::Transport::CudaIpc, r);
|
||||
} else if (useIb) {
|
||||
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
|
||||
connectionFutures[r] = communicator->connect(ibTransport, r);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::Ethernet);
|
||||
connectionFutures[r] = communicator->connect(mscclpp::Transport::Ethernet, r);
|
||||
}
|
||||
|
||||
if (isInPlace) {
|
||||
communicator->sendMemory(sendBufRegMem, r, 0);
|
||||
communicator->sendMemory(sendBufRegMem, r);
|
||||
} else {
|
||||
communicator->sendMemory(recvBufRegMem, r, 0);
|
||||
communicator->sendMemory(recvBufRegMem, r);
|
||||
}
|
||||
remoteMemFutures[r] = communicator->recvMemory(r, 0);
|
||||
remoteMemFutures[r] = communicator->recvMemory(r);
|
||||
}
|
||||
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
if (r == rank) {
|
||||
continue;
|
||||
}
|
||||
mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, connectionFutures[r].get());
|
||||
auto sema = communicator->buildSemaphore(connectionFutures[r].get(), r).get();
|
||||
|
||||
mscclpp::SemaphoreId cid = proxyService->addSemaphore(sema);
|
||||
|
||||
portChannels.emplace_back(proxyService->portChannel(cid, proxyService->addMemory(remoteMemFutures[r].get()),
|
||||
proxyService->addMemory(sendBufRegMem)));
|
||||
|
||||
@@ -386,10 +386,10 @@ void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connectionFutures.push_back(comm_->connect(r, 0, transport));
|
||||
connectionFutures.push_back(comm_->connect(transport, r));
|
||||
}
|
||||
comm_->sendMemory(localRegMemory, r, 0);
|
||||
auto remoteMemory = comm_->recvMemory(r, 0);
|
||||
comm_->sendMemory(localRegMemory, r);
|
||||
auto remoteMemory = comm_->recvMemory(r);
|
||||
remoteRegMemories.push_back(remoteMemory);
|
||||
}
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
|
||||
@@ -153,18 +153,20 @@ void SendRecvTestEngine::setupConnections() {
|
||||
std::array<int, 2> ranks = {sendToRank, recvFromRank};
|
||||
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
|
||||
auto sendConnFuture =
|
||||
comm_->connect(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
|
||||
std::vector<mscclpp::Semaphore> semaphores;
|
||||
if (recvFromRank != sendToRank) {
|
||||
auto recvConnFuture =
|
||||
comm_->connect(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, recvConnFuture.get()));
|
||||
auto sendTransport = getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice);
|
||||
auto recvTransport = getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice);
|
||||
auto connFutures = {comm_->connect(sendTransport, sendToRank), comm_->connect(recvTransport, recvFromRank)};
|
||||
auto semaFutures = {comm_->buildSemaphore(connFutures.begin()->get(), sendToRank),
|
||||
comm_->buildSemaphore((connFutures.begin() + 1)->get(), recvFromRank)};
|
||||
semaphores.emplace_back(semaFutures.begin()->get());
|
||||
semaphores.emplace_back((semaFutures.begin() + 1)->get());
|
||||
} else {
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
|
||||
memorySemaphores.push_back(memorySemaphores[0]); // reuse the send channel if worldSize is 2
|
||||
auto sendTransport = getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice);
|
||||
auto connFuture = comm_->connect(sendTransport, sendToRank);
|
||||
semaphores.emplace_back(comm_->buildSemaphore(connFuture.get(), sendToRank).get());
|
||||
semaphores.emplace_back(semaphores[0]); // reuse the semaphore if worldSize is 2
|
||||
}
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
@@ -172,9 +174,9 @@ void SendRecvTestEngine::setupConnections() {
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
auto regMem = comm_->registerMemory(devicePtrs_[i].get(), args_.maxBytes, mscclpp::Transport::CudaIpc | ibDevice);
|
||||
comm_->sendMemory(regMem, ranks[i], 0);
|
||||
comm_->sendMemory(regMem, ranks[i]);
|
||||
localMemories.push_back(regMem);
|
||||
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i], 0));
|
||||
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i]));
|
||||
}
|
||||
|
||||
// swap to make sure devicePtrs_[0] in local rank write to devicePtrs_[1] in remote rank
|
||||
@@ -182,7 +184,7 @@ void SendRecvTestEngine::setupConnections() {
|
||||
std::vector<DeviceHandle<mscclpp::MemoryChannel>> memoryChannelHandles(2);
|
||||
for (int i : {0, 1}) {
|
||||
// We assume ranks in the same node
|
||||
memoryChannels_.emplace_back(memorySemaphores[i], futureRemoteMemory[i].get(), (void*)localMemories[i].data());
|
||||
memoryChannels_.emplace_back(semaphores[i], futureRemoteMemory[i].get(), (void*)localMemories[i].data());
|
||||
}
|
||||
std::transform(memoryChannels_.begin(), memoryChannels_.end(), memoryChannelHandles.begin(),
|
||||
[](const mscclpp::MemoryChannel& memoryChannel) { return memoryChannel.deviceHandle(); });
|
||||
|
||||
@@ -29,8 +29,8 @@ TEST_F(LocalCommunicatorTest, RegisterMemory) {
|
||||
TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
|
||||
int dummy[42];
|
||||
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
|
||||
comm->sendMemory(memory, 0, 0);
|
||||
auto memoryFuture = comm->recvMemory(0, 0);
|
||||
comm->sendMemory(memory, 0);
|
||||
auto memoryFuture = comm->recvMemory(0);
|
||||
auto sameMemory = memoryFuture.get();
|
||||
EXPECT_EQ(sameMemory.data(), memory.data());
|
||||
EXPECT_EQ(sameMemory.size(), memory.size());
|
||||
|
||||
@@ -46,7 +46,7 @@ static void localPortChannelTest(mscclpp::Transport transport) {
|
||||
bootstrap->initialize(mscclpp::TcpBootstrap::createUniqueId());
|
||||
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
|
||||
auto connection = communicator->connect(/*remoteRank*/ 0, /*tag*/ 0, transport).get();
|
||||
auto connection = communicator->connect(transport, /*remoteRank*/ 0).get();
|
||||
|
||||
const size_t bytes = 4 * 1024 * 1024;
|
||||
auto srcBuff = mscclpp::GpuBuffer(bytes).memory();
|
||||
|
||||
Reference in New Issue
Block a user