mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
connect() APIs changed to return an instance instead of a shared_ptr (#680)
The key purpose is to handle the memory of all mscclpp objects internally by hiding shared pointers from user APIs.
* The `Connection` class is now a wrapper of the `BaseConnection` class, which is equivalent to the previous `Connection` class.
* `connect()` methods now return `Connection` instead of `std::shared_ptr<Connection>`.
* Removed the `connectOnSetup()` method.
This commit is contained in:
@@ -216,7 +216,7 @@ class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
|
||||
|
||||
private:
|
||||
bool disableChannelCache_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores_;
|
||||
const int nChannelsPerConnection_ = 35;
|
||||
|
||||
@@ -235,7 +235,7 @@ class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
|
||||
mscclpp::Algorithm build() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
|
||||
void initialize(std::shared_ptr<mscclpp::Communicator> comm,
|
||||
std::unordered_map<std::string, std::shared_ptr<void>>& extras);
|
||||
|
||||
@@ -1137,7 +1137,7 @@ class AllreducePacket : public mscclpp::AlgorithmBuilder {
|
||||
size_t scratchBufferSize_;
|
||||
std::shared_ptr<char> scratchBuffer_;
|
||||
const int nSegmentsForScratchBuffer_ = 2;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
|
||||
std::shared_ptr<uint32_t> deviceFlag7_;
|
||||
std::shared_ptr<uint32_t> deviceFlag28_;
|
||||
@@ -1164,7 +1164,7 @@ class AllreduceNvls : public mscclpp::AlgorithmBuilder {
|
||||
uint32_t nSwitchChannels_;
|
||||
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> memoryChannelsDeviceHandle_;
|
||||
std::vector<mscclpp::BaseMemoryChannel> baseChannels_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
};
|
||||
|
||||
class AllreduceNvlsWithCopy : public mscclpp::AlgorithmBuilder {
|
||||
@@ -1188,7 +1188,7 @@ class AllreduceNvlsWithCopy : public mscclpp::AlgorithmBuilder {
|
||||
uint32_t nSwitchChannels_;
|
||||
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> memoryChannelsDeviceHandle_;
|
||||
std::vector<mscclpp::BaseMemoryChannel> baseChannels_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
};
|
||||
|
||||
class Allreduce8 : public mscclpp::AlgorithmBuilder {
|
||||
@@ -1209,7 +1209,7 @@ class Allreduce8 : public mscclpp::AlgorithmBuilder {
|
||||
size_t scratchBufferSize_;
|
||||
std::shared_ptr<mscclpp::Communicator> comm_;
|
||||
int nChannelsPerConnection_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
std::shared_ptr<char> scratchBuffer_;
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> outputSemaphores_;
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> inputScratchSemaphores_;
|
||||
|
||||
@@ -171,7 +171,7 @@ class BroadcastAlgo6 : public mscclpp::AlgorithmBuilder {
|
||||
void* output, size_t, ncclDataType_t);
|
||||
mscclpp::AlgorithmCtxKey generateBroadcastContextKey(const void*, void*, size_t, ncclDataType_t);
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
size_t scratchMemSize_;
|
||||
std::shared_ptr<char> scratchBuffer_;
|
||||
};
|
||||
|
||||
@@ -20,7 +20,7 @@ std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<msccl
|
||||
}
|
||||
|
||||
std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
|
||||
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
const std::vector<mscclpp::Connection>& connections,
|
||||
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
|
||||
const std::vector<mscclpp::RegisteredMemory>& remoteMemories, mscclpp::RegisteredMemory localMemory,
|
||||
int nChannelsPerConnection) {
|
||||
@@ -28,7 +28,7 @@ std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
|
||||
size_t nConnections = connections.size();
|
||||
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
|
||||
for (size_t cid = 0; cid < nConnections; ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
channels.emplace_back(memorySemaphores[idx * nConnections + cid], remoteMemories[cid], localMemory, nullptr);
|
||||
}
|
||||
}
|
||||
@@ -36,25 +36,25 @@ std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
|
||||
return channels;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> setupConnections(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
|
||||
if (i == comm->bootstrap()->getRank()) continue;
|
||||
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
|
||||
}
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const auto& future) { return future.get(); });
|
||||
return connections;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
|
||||
int nChannelsPerConnection) {
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
memorySemaphores.emplace_back(
|
||||
std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*(comm), connections[cid]));
|
||||
}
|
||||
@@ -117,14 +117,14 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SwitchChannel>> setupNvlsChannelD
|
||||
}
|
||||
|
||||
std::vector<mscclpp::BaseMemoryChannel> setupBaseMemoryChannels(
|
||||
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
const std::vector<mscclpp::Connection>& connections,
|
||||
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
|
||||
int nChannelsPerConnection) {
|
||||
std::vector<mscclpp::BaseMemoryChannel> channels;
|
||||
size_t nConnections = connections.size();
|
||||
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
|
||||
for (size_t cid = 0; cid < nConnections; ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
channels.emplace_back(memorySemaphores[idx * nConnections + cid]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,15 +33,15 @@ std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<msccl
|
||||
mscclpp::RegisteredMemory localMemory);
|
||||
|
||||
std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
|
||||
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
const std::vector<mscclpp::Connection>& connections,
|
||||
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
|
||||
const std::vector<mscclpp::RegisteredMemory>& remoteMemories, mscclpp::RegisteredMemory localMemory,
|
||||
int nChannelsPerConnection);
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> setupConnections(std::shared_ptr<mscclpp::Communicator> comm);
|
||||
std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Communicator> comm);
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
|
||||
int nChannelsPerConnection);
|
||||
|
||||
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
@@ -57,7 +57,7 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SwitchChannel>> setupNvlsChannelD
|
||||
const std::vector<mscclpp::SwitchChannel>& nvlsChannels);
|
||||
|
||||
std::vector<mscclpp::BaseMemoryChannel> setupBaseMemoryChannels(
|
||||
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
const std::vector<mscclpp::Connection>& connections,
|
||||
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
|
||||
int nChannelsPerConnection);
|
||||
|
||||
|
||||
@@ -83,8 +83,8 @@ The connection is created by calling `connect` on the context object:
|
||||
|
||||
```cpp
|
||||
// From gpu_ping_pong.cu, lines 76 and 82
|
||||
std::shared_ptr<mscclpp::Connection> conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
|
||||
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
mscclpp::Connection conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
|
||||
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
```
|
||||
|
||||
The `localEndpoint` and `remoteEndpoint` parameters specify which endpoints are used for the connection. A connection is asymmetric by nature, meaning that we need to create one connection for each endpoint. In this case, `conn0` is created for `ep0` to communicate with `ep1`, and `conn1` is created for `ep1` to communicate with `ep0`.
|
||||
@@ -101,7 +101,7 @@ sendToProcessB(serializedEp0); // send serializedEp0 to Process B using any IPC
|
||||
mscclpp::Endpoint ep1 = ctx->createEndpoint({transport, {mscclpp::DeviceType::GPU, 1}});
|
||||
std::vector<char> serializedEp0 = recvFromProcessA(); // receive serializedEp0 from Process A
|
||||
mscclpp::Endpoint ep0 = mscclpp::Endpoint::deserialize(serializedEp0);
|
||||
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
```
|
||||
|
||||
## SemaphoreStub and Semaphore
|
||||
|
||||
@@ -107,18 +107,18 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
|
||||
std::vector<mscclpp::Connection> conns_;
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService_;
|
||||
int worldSize_;
|
||||
|
||||
void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
worldSize_ = comm->bootstrap()->getNranks();
|
||||
for (int i = 0; i < worldSize_; i++) {
|
||||
if (i == comm->bootstrap()->getRank()) continue;
|
||||
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
|
||||
}
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const auto& future) { return future.get(); });
|
||||
this->conns_ = std::move(connections);
|
||||
|
||||
@@ -73,13 +73,13 @@ int main() {
|
||||
log("GPU 0: Creating a connection and a semaphore stub ...");
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(0));
|
||||
std::shared_ptr<mscclpp::Connection> conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
|
||||
mscclpp::Connection conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
|
||||
mscclpp::SemaphoreStub semaStub0(conn0);
|
||||
|
||||
log("GPU 1: Creating a connection and a semaphore stub ...");
|
||||
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(1));
|
||||
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
|
||||
mscclpp::SemaphoreStub semaStub1(conn1);
|
||||
|
||||
log("GPU 0: Creating a semaphore and a memory channel ...");
|
||||
|
||||
@@ -425,6 +425,7 @@ struct EndpointConfig {
|
||||
|
||||
class Context;
|
||||
class Connection;
|
||||
class BaseConnection;
|
||||
class RegisteredMemory;
|
||||
class SemaphoreStub;
|
||||
class Semaphore;
|
||||
@@ -474,7 +475,7 @@ class Endpoint {
|
||||
std::shared_ptr<Impl> pimpl_;
|
||||
|
||||
friend class Context;
|
||||
friend class Connection;
|
||||
friend class BaseConnection;
|
||||
};
|
||||
|
||||
/// Context for communication. This provides a low-level interface for forming connections in use-cases
|
||||
@@ -521,8 +522,8 @@ class Context : public std::enable_shared_from_this<Context> {
|
||||
///
|
||||
/// @param localEndpoint The local endpoint.
|
||||
/// @param remoteEndpoint The remote endpoint.
|
||||
/// @return A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
|
||||
/// @return A connection object.
|
||||
Connection connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
|
||||
|
||||
private:
|
||||
Context();
|
||||
@@ -531,7 +532,7 @@ class Context : public std::enable_shared_from_this<Context> {
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
|
||||
friend class Endpoint;
|
||||
friend class Connection;
|
||||
friend class BaseConnection;
|
||||
friend class RegisteredMemory;
|
||||
friend class SemaphoreStub;
|
||||
};
|
||||
@@ -578,7 +579,7 @@ class RegisteredMemory {
|
||||
std::shared_ptr<Impl> pimpl_;
|
||||
|
||||
friend class Context;
|
||||
friend class Connection;
|
||||
friend class BaseConnection;
|
||||
friend class SemaphoreStub;
|
||||
friend class Semaphore;
|
||||
};
|
||||
@@ -587,12 +588,7 @@ class RegisteredMemory {
|
||||
class Connection {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param context The context associated with the connection.
|
||||
/// @param localEndpoint The local endpoint of the connection.
|
||||
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint);
|
||||
|
||||
/// Destructor.
|
||||
virtual ~Connection() = default;
|
||||
Connection() = default;
|
||||
|
||||
/// Write data from a source RegisteredMemory to a destination RegisteredMemory.
|
||||
///
|
||||
@@ -601,8 +597,7 @@ class Connection {
|
||||
/// @param src The source RegisteredMemory.
|
||||
/// @param srcOffset The offset in bytes from the start of the source RegisteredMemory.
|
||||
/// @param size The number of bytes to write.
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) = 0;
|
||||
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
|
||||
|
||||
/// Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
|
||||
///
|
||||
@@ -610,19 +605,19 @@ class Connection {
|
||||
/// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory.
|
||||
/// @param src A pointer to the value to update.
|
||||
/// @param newValue The new value to write.
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue);
|
||||
|
||||
/// Flush any pending writes to the remote process.
|
||||
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
|
||||
virtual void flush(int64_t timeoutUsec = -1) = 0;
|
||||
void flush(int64_t timeoutUsec = -1);
|
||||
|
||||
/// Get the transport used by the local process.
|
||||
/// @return The transport used by the local process.
|
||||
virtual Transport transport() const = 0;
|
||||
Transport transport() const;
|
||||
|
||||
/// Get the transport used by the remote process.
|
||||
/// @return The transport used by the remote process.
|
||||
virtual Transport remoteTransport() const = 0;
|
||||
Transport remoteTransport() const;
|
||||
|
||||
/// Get the context associated with this connection.
|
||||
/// @return A shared pointer to the context associated with this connection.
|
||||
@@ -636,22 +631,23 @@ class Connection {
|
||||
/// @return The maximum number of write requests that can be queued.
|
||||
int getMaxWriteQueueSize() const;
|
||||
|
||||
protected:
|
||||
static const Endpoint::Impl& getImpl(const Endpoint& endpoint);
|
||||
static const RegisteredMemory::Impl& getImpl(const RegisteredMemory& memory);
|
||||
static Context::Impl& getImpl(Context& context);
|
||||
private:
|
||||
Connection(std::shared_ptr<BaseConnection> impl);
|
||||
std::shared_ptr<BaseConnection> impl_;
|
||||
|
||||
std::shared_ptr<Context> context_;
|
||||
Endpoint localEndpoint_;
|
||||
int maxWriteQueueSize_;
|
||||
friend class Context;
|
||||
friend class Communicator;
|
||||
friend class SemaphoreStub;
|
||||
friend class Semaphore;
|
||||
friend class ProxyService;
|
||||
};
|
||||
|
||||
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
|
||||
class SemaphoreStub {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// @param connection A shared pointer to the connection associated with this semaphore.
|
||||
SemaphoreStub(std::shared_ptr<Connection> connection);
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
SemaphoreStub(const Connection& connection);
|
||||
|
||||
/// Get the memory associated with this semaphore.
|
||||
/// @return A reference to the registered memory for this semaphore.
|
||||
@@ -686,8 +682,8 @@ class Semaphore {
|
||||
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
|
||||
|
||||
/// Get the connection associated with this semaphore.
|
||||
/// @return A shared pointer to the connection.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
/// @return The connection.
|
||||
Connection& connection();
|
||||
|
||||
/// Get the local memory associated with this semaphore.
|
||||
/// @return A reference to the local registered memory.
|
||||
@@ -873,34 +869,23 @@ class Communicator {
|
||||
/// @param localEndpoint The local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
/// @return A future of the connection.
|
||||
///
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
|
||||
std::shared_future<Connection> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
|
||||
|
||||
/// Connect to a remote rank. Wrapper of `connect(localEndpoint, remoteRank, tag)`.
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
|
||||
int tag = 0);
|
||||
|
||||
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
|
||||
shared_future<std::shared_ptr<Connection>>
|
||||
connect(int remoteRank, int tag, EndpointConfig localConfig);
|
||||
|
||||
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
|
||||
std::shared_ptr<Connection>>
|
||||
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
}
|
||||
/// @return A future of the connection.
|
||||
std::shared_future<Connection> connect(const EndpointConfig& localConfig, int remoteRank, int tag = 0);
|
||||
|
||||
/// Build a semaphore for cross-process synchronization.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the operation.
|
||||
/// @return A future of the built semaphore.
|
||||
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
|
||||
std::shared_future<Semaphore> buildSemaphore(const Connection& connection, int remoteRank, int tag = 0);
|
||||
|
||||
/// Get the remote rank a connection is connected to.
|
||||
///
|
||||
|
||||
@@ -33,7 +33,7 @@ class ProxyService : public BaseProxyService {
|
||||
/// Build and add a semaphore to the proxy service.
|
||||
/// @param connection The connection associated with the semaphore.
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
SemaphoreId buildAndAddSemaphore(Communicator& communicator, const Connection& connection);
|
||||
|
||||
/// Add a semaphore to the proxy service.
|
||||
/// @param semaphore The semaphore to be added
|
||||
@@ -83,7 +83,7 @@ class ProxyService : public BaseProxyService {
|
||||
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests_;
|
||||
std::unordered_map<std::shared_ptr<BaseConnection>, int> inflightRequests_;
|
||||
|
||||
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
|
||||
};
|
||||
|
||||
@@ -27,11 +27,11 @@ class Host2DeviceSemaphore {
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
Host2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
Host2DeviceSemaphore(Communicator& communicator, const Connection& connection);
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
Connection& connection();
|
||||
|
||||
/// Signal the device.
|
||||
void signal();
|
||||
@@ -59,11 +59,11 @@ class Host2HostSemaphore {
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
|
||||
/// Host2HostSemaphore.
|
||||
Host2HostSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
Host2HostSemaphore(Communicator& communicator, const Connection& connection);
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
Connection& connection();
|
||||
|
||||
/// Signal the remote host.
|
||||
void signal();
|
||||
@@ -92,11 +92,11 @@ class MemoryDevice2DeviceSemaphore {
|
||||
/// Constructor.
|
||||
/// @param communicator The communicator.
|
||||
/// @param connection The connection associated with this semaphore.
|
||||
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
|
||||
MemoryDevice2DeviceSemaphore(Communicator& communicator, const Connection& connection);
|
||||
|
||||
/// Returns the connection.
|
||||
/// @return The connection associated with this semaphore.
|
||||
std::shared_ptr<Connection> connection() const;
|
||||
Connection& connection();
|
||||
|
||||
/// Device-side handle for MemoryDevice2DeviceSemaphore.
|
||||
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;
|
||||
|
||||
@@ -202,7 +202,7 @@ void register_core(nb::module_& m) {
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
|
||||
.def(nb::init<const Connection&>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
@@ -215,7 +215,7 @@ void register_core(nb::module_& m) {
|
||||
.def("remote_memory", &Semaphore::remoteMemory);
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
def_shared_future<Connection>(m, "Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
@@ -231,8 +231,8 @@ void register_core(nb::module_& m) {
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int,
|
||||
int)>(&Communicator::connect),
|
||||
static_cast<std::shared_future<Connection> (Communicator::*)(const EndpointConfig&, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def(
|
||||
"connect_on_setup",
|
||||
|
||||
@@ -12,7 +12,7 @@ using namespace mscclpp;
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
|
||||
@@ -27,7 +27,7 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("poll", &Host2HostSemaphore::poll)
|
||||
@@ -36,7 +36,7 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
|
||||
@@ -21,13 +21,13 @@ class MyProxyService {
|
||||
private:
|
||||
int deviceNumaNode_;
|
||||
int my_rank_, nranks_, dataSize_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
|
||||
std::vector<mscclpp::Connection> connections_;
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
mscclpp::Proxy proxy_;
|
||||
|
||||
public:
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<std::shared_ptr<mscclpp::Connection>> conns,
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
|
||||
: my_rank_(my_rank),
|
||||
@@ -46,10 +46,10 @@ class MyProxyService {
|
||||
int dataSizePerRank = dataSize_ / nranks_;
|
||||
for (int r = 1; r < nranks_; ++r) {
|
||||
int nghr = (my_rank_ + r) % nranks_;
|
||||
connections_[nghr]->write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
semaphores_[nghr]->signal();
|
||||
connections_[nghr]->flush();
|
||||
connections_[nghr].flush();
|
||||
}
|
||||
return mscclpp::ProxyHandlerResult::Continue;
|
||||
}
|
||||
@@ -63,7 +63,7 @@ class MyProxyService {
|
||||
|
||||
void init_mscclpp_proxy_test_module(nb::module_ &m) {
|
||||
nb::class_<MyProxyService>(m, "MyProxyService")
|
||||
.def(nb::init<int, int, int, std::vector<std::shared_ptr<mscclpp::Connection>>,
|
||||
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
|
||||
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
|
||||
|
||||
@@ -99,16 +99,16 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
|
||||
int remoteRank, int tag) {
|
||||
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const Endpoint& localEndpoint, int remoteRank,
|
||||
int tag) {
|
||||
if (remoteRank == bootstrap()->getRank()) {
|
||||
// Connection to self
|
||||
auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
|
||||
auto connection = context()->connect(localEndpoint, remoteEndpoint);
|
||||
std::promise<std::shared_ptr<Connection>> promise;
|
||||
std::promise<Connection> promise;
|
||||
promise.set_value(connection);
|
||||
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
|
||||
return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
|
||||
pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
|
||||
return std::shared_future<Connection>(promise.get_future());
|
||||
}
|
||||
|
||||
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
|
||||
@@ -123,27 +123,22 @@ MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::co
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteEndpoint = Endpoint::deserialize(data);
|
||||
auto connection = context()->connect(localEndpoint, remoteEndpoint);
|
||||
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
|
||||
pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
|
||||
return connection;
|
||||
});
|
||||
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
|
||||
auto shared_future = std::shared_future<Connection>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Connection>>(shared_future));
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
|
||||
int remoteRank, int tag) {
|
||||
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const EndpointConfig& localConfig, int remoteRank,
|
||||
int tag) {
|
||||
auto localEndpoint = context()->createEndpoint(localConfig);
|
||||
return connect(localEndpoint, remoteRank, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
|
||||
EndpointConfig localConfig) {
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::shared_ptr<Connection> connection,
|
||||
int remoteRank, int tag) {
|
||||
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(const Connection& connection, int remoteRank,
|
||||
int tag) {
|
||||
SemaphoreStub localStub(connection);
|
||||
bootstrap()->send(localStub.serialize(), remoteRank, tag);
|
||||
|
||||
@@ -165,11 +160,11 @@ MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
|
||||
return pimpl_->connectionInfos_.at(&connection).remoteRank;
|
||||
return pimpl_->connectionInfos_.at(connection.impl_.get()).remoteRank;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
|
||||
return pimpl_->connectionInfos_.at(&connection).tag;
|
||||
return pimpl_->connectionInfos_.at(connection.impl_.get()).tag;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -32,28 +32,54 @@ static bool isSameProcess(const Endpoint& a, const Endpoint& b) {
|
||||
return a.hostHash() == b.hostHash() && a.pidHash() == b.pidHash();
|
||||
}
|
||||
|
||||
// Connection
|
||||
// BaseConnection
|
||||
|
||||
const Endpoint::Impl& Connection::getImpl(const Endpoint& endpoint) { return *(endpoint.pimpl_); }
|
||||
const Endpoint::Impl& BaseConnection::getImpl(const Endpoint& endpoint) { return *(endpoint.pimpl_); }
|
||||
|
||||
const RegisteredMemory::Impl& Connection::getImpl(const RegisteredMemory& memory) { return *(memory.pimpl_); }
|
||||
const RegisteredMemory::Impl& BaseConnection::getImpl(const RegisteredMemory& memory) { return *(memory.pimpl_); }
|
||||
|
||||
Context::Impl& Connection::getImpl(Context& context) { return *(context.pimpl_); }
|
||||
Context::Impl& BaseConnection::getImpl(Context& context) { return *(context.pimpl_); }
|
||||
|
||||
MSCCLPP_API_CPP Connection::Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
|
||||
MSCCLPP_API_CPP BaseConnection::BaseConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
|
||||
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Context> Connection::context() const { return context_; }
|
||||
MSCCLPP_API_CPP std::shared_ptr<Context> BaseConnection::context() const { return context_; }
|
||||
|
||||
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return localEndpoint_.device(); }
|
||||
MSCCLPP_API_CPP const Device& BaseConnection::localDevice() const { return localEndpoint_.device(); }
|
||||
|
||||
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
|
||||
MSCCLPP_API_CPP int BaseConnection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
|
||||
|
||||
// Connection wrapper
|
||||
|
||||
Connection::Connection(std::shared_ptr<BaseConnection> impl) : impl_(impl) {}
|
||||
|
||||
MSCCLPP_API_CPP void Connection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src,
|
||||
uint64_t srcOffset, uint64_t size) {
|
||||
impl_->write(dst, dstOffset, src, srcOffset, size);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Connection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src,
|
||||
uint64_t newValue) {
|
||||
impl_->updateAndSync(dst, dstOffset, src, newValue);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Connection::flush(int64_t timeoutUsec) { impl_->flush(timeoutUsec); }
|
||||
|
||||
MSCCLPP_API_CPP Transport Connection::transport() const { return impl_->transport(); }
|
||||
|
||||
MSCCLPP_API_CPP Transport Connection::remoteTransport() const { return impl_->remoteTransport(); }
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Context> Connection::context() const { return impl_->context(); }
|
||||
|
||||
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return impl_->localDevice(); }
|
||||
|
||||
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return impl_->getMaxWriteQueueSize(); }
|
||||
|
||||
// CudaIpcConnection
|
||||
|
||||
CudaIpcConnection::CudaIpcConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
|
||||
const Endpoint& remoteEndpoint)
|
||||
: Connection(context, localEndpoint) {
|
||||
: BaseConnection(context, localEndpoint) {
|
||||
if (localEndpoint.transport() != Transport::CudaIpc || remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
THROW(CONN, Error, ErrorCode::InternalError, "CudaIpc transport is required for CudaIpcConnection");
|
||||
}
|
||||
@@ -163,7 +189,7 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
|
||||
|
||||
IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
|
||||
const Endpoint& remoteEndpoint)
|
||||
: Connection(context, localEndpoint),
|
||||
: BaseConnection(context, localEndpoint),
|
||||
transport_(localEndpoint.transport()),
|
||||
remoteTransport_(remoteEndpoint.transport()),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
@@ -276,7 +302,7 @@ void IBConnection::flush(int64_t timeoutUsec) {
|
||||
|
||||
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
|
||||
const Endpoint& remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize)
|
||||
: Connection(context, localEndpoint),
|
||||
: BaseConnection(context, localEndpoint),
|
||||
abortFlag_(0),
|
||||
sendBufferSize_(sendBufferSize),
|
||||
recvBufferSize_(recvBufferSize) {
|
||||
|
||||
@@ -76,8 +76,7 @@ MSCCLPP_API_CPP Endpoint Context::createEndpoint(EndpointConfig config) {
|
||||
return Endpoint(std::make_shared<Endpoint::Impl>(config, *pimpl_));
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &localEndpoint,
|
||||
const Endpoint &remoteEndpoint) {
|
||||
MSCCLPP_API_CPP Connection Context::connect(const Endpoint &localEndpoint, const Endpoint &remoteEndpoint) {
|
||||
if (localEndpoint.device().type == DeviceType::GPU && localEndpoint.device().id < 0) {
|
||||
throw Error("No GPU device ID provided for local endpoint", ErrorCode::InvalidUsage);
|
||||
}
|
||||
@@ -93,7 +92,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
|
||||
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
|
||||
throw Error(ss.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
std::shared_ptr<Connection> conn;
|
||||
std::shared_ptr<BaseConnection> conn;
|
||||
if (localTransport == Transport::CudaIpc) {
|
||||
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else if (AllIBTransports.has(localTransport)) {
|
||||
@@ -103,7 +102,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
|
||||
} else {
|
||||
throw Error("Unsupported transport", ErrorCode::InternalError);
|
||||
}
|
||||
return conn;
|
||||
return Connection(conn);
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -112,7 +112,7 @@ namespace mscclpp {
|
||||
|
||||
struct ExecutionContext {
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::unordered_map<int, std::shared_ptr<Connection>> connections;
|
||||
std::unordered_map<int, Connection> connections;
|
||||
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
|
||||
MemoryId localMemoryIdBegin = MemoryId(0);
|
||||
|
||||
@@ -270,7 +270,7 @@ struct Executor::Impl {
|
||||
};
|
||||
|
||||
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
Transport transport =
|
||||
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
|
||||
@@ -339,10 +339,10 @@ struct Executor::Impl {
|
||||
auto connection = context.connections.at(peer);
|
||||
if (info.channelType == ChannelType::MEMORY) {
|
||||
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
|
||||
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
|
||||
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
|
||||
} else if (info.channelType == ChannelType::PORT) {
|
||||
futureProxySemaphores.push_back(this->comm->buildSemaphore(
|
||||
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
|
||||
futureProxySemaphores.push_back(this->comm->buildSemaphore(connection, this->comm->remoteRankOf(connection),
|
||||
this->comm->tagOf(connection)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ struct ConnectionInfo {
|
||||
struct Communicator::Impl {
|
||||
std::shared_ptr<Bootstrap> bootstrap_;
|
||||
std::shared_ptr<Context> context_;
|
||||
std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
|
||||
std::unordered_map<const BaseConnection*, ConnectionInfo> connectionInfos_;
|
||||
|
||||
// Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
|
||||
// If the RecvItem gets ready, it will be removed at the next call to getLastRecvItem.
|
||||
|
||||
@@ -15,7 +15,46 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class CudaIpcConnection : public Connection {
|
||||
/// Internal base class for connection implementations between two processes.
|
||||
class BaseConnection {
|
||||
public:
|
||||
BaseConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint);
|
||||
|
||||
virtual ~BaseConnection() = default;
|
||||
|
||||
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
|
||||
uint64_t size) = 0;
|
||||
|
||||
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
|
||||
|
||||
virtual void flush(int64_t timeoutUsec = -1) = 0;
|
||||
|
||||
virtual Transport transport() const = 0;
|
||||
|
||||
virtual Transport remoteTransport() const = 0;
|
||||
|
||||
std::shared_ptr<Context> context() const;
|
||||
|
||||
const Device& localDevice() const;
|
||||
|
||||
int getMaxWriteQueueSize() const;
|
||||
|
||||
protected:
|
||||
friend class Context;
|
||||
friend class CudaIpcConnection;
|
||||
friend class IBConnection;
|
||||
friend class EthernetConnection;
|
||||
|
||||
static const Endpoint::Impl& getImpl(const Endpoint& endpoint);
|
||||
static const RegisteredMemory::Impl& getImpl(const RegisteredMemory& memory);
|
||||
static Context::Impl& getImpl(Context& context);
|
||||
|
||||
std::shared_ptr<Context> context_;
|
||||
Endpoint localEndpoint_;
|
||||
int maxWriteQueueSize_;
|
||||
};
|
||||
|
||||
class CudaIpcConnection : public BaseConnection {
|
||||
private:
|
||||
std::shared_ptr<CudaIpcStream> stream_;
|
||||
|
||||
@@ -33,7 +72,7 @@ class CudaIpcConnection : public Connection {
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
};
|
||||
|
||||
class IBConnection : public Connection {
|
||||
class IBConnection : public BaseConnection {
|
||||
private:
|
||||
Transport transport_;
|
||||
Transport remoteTransport_;
|
||||
@@ -56,7 +95,7 @@ class IBConnection : public Connection {
|
||||
void flush(int64_t timeoutUsec) override;
|
||||
};
|
||||
|
||||
class EthernetConnection : public Connection {
|
||||
class EthernetConnection : public BaseConnection {
|
||||
private:
|
||||
std::unique_ptr<Socket> sendSocket_;
|
||||
std::unique_ptr<Socket> recvSocket_;
|
||||
|
||||
@@ -42,7 +42,7 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection) {
|
||||
const Connection& connection) {
|
||||
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection));
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
@@ -89,13 +89,14 @@ MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
|
||||
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
|
||||
|
||||
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
|
||||
auto& numRequests = inflightRequests_[semaphore->connection()];
|
||||
auto& conn = semaphore->connection();
|
||||
int maxWriteQueueSize = conn.getMaxWriteQueueSize();
|
||||
auto& numRequests = inflightRequests_[conn.impl_];
|
||||
|
||||
if (trigger.fields.type & TriggerData) {
|
||||
RegisteredMemory& dst = memories_[trigger.fields.dstMemoryId];
|
||||
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
|
||||
semaphore->connection()->write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
|
||||
conn.write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
|
||||
numRequests++;
|
||||
}
|
||||
|
||||
@@ -106,7 +107,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
|
||||
|
||||
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
|
||||
(maxWriteQueueSize != -1 && numRequests > maxWriteQueueSize)) {
|
||||
semaphore->connection()->flush();
|
||||
conn.flush();
|
||||
numRequests = 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "api.h"
|
||||
#include "atomic.hpp"
|
||||
#include "connection.hpp"
|
||||
#include "context.hpp"
|
||||
#include "debug.h"
|
||||
#include "registered_memory.hpp"
|
||||
@@ -14,7 +15,7 @@
|
||||
namespace mscclpp {
|
||||
|
||||
struct SemaphoreStub::Impl {
|
||||
Impl(std::shared_ptr<Connection> connection);
|
||||
Impl(const Connection& connection);
|
||||
|
||||
Impl(const RegisteredMemory& idMemory, const Device& device);
|
||||
|
||||
@@ -22,7 +23,7 @@ struct SemaphoreStub::Impl {
|
||||
|
||||
std::shared_ptr<uint64_t> gpuCallocToken(std::shared_ptr<Context> context);
|
||||
|
||||
std::shared_ptr<Connection> connection_;
|
||||
Connection connection_;
|
||||
std::shared_ptr<uint64_t> token_;
|
||||
RegisteredMemory idMemory_;
|
||||
Device device_;
|
||||
@@ -41,9 +42,9 @@ std::shared_ptr<uint64_t> SemaphoreStub::Impl::gpuCallocToken(std::shared_ptr<Co
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
}
|
||||
|
||||
SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection) : connection_(connection) {
|
||||
SemaphoreStub::Impl::Impl(const Connection& connection) : connection_(connection) {
|
||||
// Allocate a semaphore ID on the local device
|
||||
const Device& localDevice = connection_->localDevice();
|
||||
const Device& localDevice = connection_.localDevice();
|
||||
if (localDevice.type == DeviceType::CPU) {
|
||||
token_ = std::make_shared<uint64_t>(0);
|
||||
} else if (localDevice.type == DeviceType::GPU) {
|
||||
@@ -51,12 +52,11 @@ SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection) : connection_(
|
||||
throw Error("Local GPU ID is not provided", ErrorCode::InvalidUsage);
|
||||
}
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(localDevice.id));
|
||||
token_ = gpuCallocToken(connection_->context());
|
||||
token_ = gpuCallocToken(connection_.context());
|
||||
} else {
|
||||
throw Error("Unsupported local device type", ErrorCode::InvalidUsage);
|
||||
}
|
||||
idMemory_ =
|
||||
std::move(connection->context()->registerMemory(token_.get(), sizeof(uint64_t), connection_->transport()));
|
||||
idMemory_ = std::move(connection_.context()->registerMemory(token_.get(), sizeof(uint64_t), connection_.transport()));
|
||||
}
|
||||
|
||||
SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device)
|
||||
@@ -64,7 +64,7 @@ SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device
|
||||
|
||||
SemaphoreStub::SemaphoreStub(std::shared_ptr<Impl> pimpl) : pimpl_(std::move(pimpl)) {}
|
||||
|
||||
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(std::shared_ptr<Connection> connection)
|
||||
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(const Connection& connection)
|
||||
: pimpl_(std::make_shared<Impl>(connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> SemaphoreStub::serialize() const {
|
||||
@@ -103,17 +103,15 @@ Semaphore::Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remote
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Semaphore::connection() const {
|
||||
return pimpl_->localStub_.pimpl_->connection_;
|
||||
}
|
||||
MSCCLPP_API_CPP Connection& Semaphore::connection() { return pimpl_->localStub_.pimpl_->connection_; }
|
||||
|
||||
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::localMemory() const { return pimpl_->localStub_.memory(); }
|
||||
|
||||
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::remoteMemory() const { return pimpl_->remoteStubMemory_; }
|
||||
|
||||
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, std::shared_ptr<Connection> connection) {
|
||||
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, const Connection& connection) {
|
||||
auto semaphoreFuture =
|
||||
communicator.buildSemaphore(connection, communicator.remoteRankOf(*connection), communicator.tagOf(*connection));
|
||||
communicator.buildSemaphore(connection, communicator.remoteRankOf(connection), communicator.tagOf(connection));
|
||||
return semaphoreFuture.get();
|
||||
}
|
||||
|
||||
@@ -121,19 +119,18 @@ MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(const Semaphore& sema
|
||||
: semaphore_(semaphore),
|
||||
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
|
||||
outboundToken_(std::make_unique<uint64_t>()) {
|
||||
if (connection()->localDevice().type != DeviceType::GPU) {
|
||||
if (connection().localDevice().type != DeviceType::GPU) {
|
||||
throw Error("Local endpoint device type of Host2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator, const Connection& connection)
|
||||
: Host2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() const { return semaphore_.connection(); }
|
||||
MSCCLPP_API_CPP Connection& Host2DeviceSemaphore::connection() { return semaphore_.connection(); }
|
||||
|
||||
MSCCLPP_API_CPP void Host2DeviceSemaphore::signal() {
|
||||
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
connection().updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() const {
|
||||
@@ -147,22 +144,21 @@ MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(const Semaphore& semaphor
|
||||
: semaphore_(semaphore),
|
||||
expectedInboundToken_(std::make_unique<uint64_t>()),
|
||||
outboundToken_(std::make_unique<uint64_t>()) {
|
||||
if (connection()->transport() == Transport::CudaIpc) {
|
||||
if (connection().transport() == Transport::CudaIpc) {
|
||||
throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (connection()->localDevice().type != DeviceType::CPU) {
|
||||
if (connection().localDevice().type != DeviceType::CPU) {
|
||||
throw Error("Local endpoint device type of Host2HostSemaphore should be CPU", ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator, const Connection& connection)
|
||||
: Host2HostSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() const { return semaphore_.connection(); }
|
||||
MSCCLPP_API_CPP Connection& Host2HostSemaphore::connection() { return semaphore_.connection(); }
|
||||
|
||||
MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
|
||||
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
connection().updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
|
||||
@@ -187,18 +183,16 @@ MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const
|
||||
: semaphore_(semaphore),
|
||||
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
|
||||
outboundToken_(detail::gpuCallocUnique<uint64_t>()) {
|
||||
if (connection()->localDevice().type != DeviceType::GPU) {
|
||||
if (connection().localDevice().type != DeviceType::GPU) {
|
||||
throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(Communicator& communicator,
|
||||
std::shared_ptr<Connection> connection)
|
||||
const Connection& connection)
|
||||
: MemoryDevice2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> MemoryDevice2DeviceSemaphore::connection() const {
|
||||
return semaphore_.connection();
|
||||
}
|
||||
MSCCLPP_API_CPP Connection& MemoryDevice2DeviceSemaphore::connection() { return semaphore_.connection(); }
|
||||
|
||||
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSemaphore::deviceHandle() const {
|
||||
MemoryDevice2DeviceSemaphore::DeviceHandle device;
|
||||
|
||||
@@ -213,7 +213,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connections(world_size);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
|
||||
@@ -85,7 +85,7 @@ class MyProxyService {
|
||||
mscclpp::RegisteredMemory localMemory_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores1_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores2_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
|
||||
std::vector<mscclpp::Connection> connections_;
|
||||
mscclpp::Proxy proxy_;
|
||||
|
||||
public:
|
||||
@@ -98,7 +98,7 @@ class MyProxyService {
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionsFuture(world_size);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
|
||||
|
||||
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
@@ -138,15 +138,15 @@ class MyProxyService {
|
||||
int dataSizePerRank = dataSize_ / world_size;
|
||||
for (int r = 1; r < world_size; ++r) {
|
||||
int nghr = (rank + r) % world_size;
|
||||
connections_[nghr]->write(remoteMemories_[nghr], rank * dataSizePerRank, localMemory_, rank * dataSizePerRank,
|
||||
dataSizePerRank);
|
||||
connections_[nghr].write(remoteMemories_[nghr], rank * dataSizePerRank, localMemory_, rank * dataSizePerRank,
|
||||
dataSizePerRank);
|
||||
if (triggerRaw.fst == 1)
|
||||
deviceSemaphores1_[nghr]->signal();
|
||||
else
|
||||
deviceSemaphores2_[nghr]->signal();
|
||||
if ((flusher % 64) == 0 && mscclpp::AllIBTransports.has(connections_[nghr]->transport())) {
|
||||
if ((flusher % 64) == 0 && mscclpp::AllIBTransports.has(connections_[nghr].transport())) {
|
||||
// if we are using IB transport, we need a flush every once in a while
|
||||
connections_[nghr]->flush();
|
||||
connections_[nghr].flush();
|
||||
}
|
||||
}
|
||||
flusher++;
|
||||
|
||||
@@ -44,8 +44,8 @@ void CommunicatorTestBase::TearDown() {
|
||||
void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; }
|
||||
|
||||
void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) {
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> cpuConnectionFutures(numRanksToUse);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(numRanksToUse);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> cpuConnectionFutures(numRanksToUse);
|
||||
for (int i = 0; i < numRanksToUse; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
|
||||
@@ -158,9 +158,9 @@ void CommunicatorTest::writeToRemote(int dataCountPerRank) {
|
||||
if (i != gEnv->rank) {
|
||||
auto& conn = connections.at(i);
|
||||
auto& peerMemory = remoteMemory[n].at(i);
|
||||
conn->write(peerMemory, gEnv->rank * dataCountPerRank * sizeof(int), localMemory[n],
|
||||
gEnv->rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
|
||||
conn->flush();
|
||||
conn.write(peerMemory, gEnv->rank * dataCountPerRank * sizeof(int), localMemory[n],
|
||||
gEnv->rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
|
||||
conn.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +262,7 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
for (auto entry : cpuConnections) {
|
||||
auto& conn = entry.second;
|
||||
// Host2HostSemaphore cannot be used with CudaIpc transport
|
||||
if (conn->transport() == mscclpp::Transport::CudaIpc) continue;
|
||||
if (conn.transport() == mscclpp::Transport::CudaIpc) continue;
|
||||
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
|
||||
}
|
||||
communicator->bootstrap()->barrier();
|
||||
@@ -273,25 +273,25 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
|
||||
semaphores[i]->signal();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
|
||||
semaphores[i]->wait();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
|
||||
semaphores[i]->signal();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gEnv->worldSize; i++) {
|
||||
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
|
||||
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
|
||||
semaphores[i]->wait();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
|
||||
const bool isInPlace = (outputBuff == nullptr);
|
||||
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc;
|
||||
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
|
||||
mscclpp::RegisteredMemory inputBufRegMem = communicator->registerMemory(inputBuff, inputBuffBytes, transport);
|
||||
|
||||
@@ -108,8 +108,8 @@ class CommunicatorTestBase : public MultiProcessTest {
|
||||
std::shared_ptr<mscclpp::Communicator> communicator;
|
||||
mscclpp::Transport ibTransport;
|
||||
std::vector<mscclpp::RegisteredMemory> registeredMemories;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cpuConnections;
|
||||
std::unordered_map<int, mscclpp::Connection> connections;
|
||||
std::unordered_map<int, mscclpp::Connection> cpuConnections;
|
||||
};
|
||||
|
||||
class CommunicatorTest : public CommunicatorTestBase {
|
||||
|
||||
@@ -28,7 +28,7 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
|
||||
if (useIb) transport |= ibTransport;
|
||||
if (useEthernet) transport |= mscclpp::Transport::Ethernet;
|
||||
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
|
||||
mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport);
|
||||
|
||||
@@ -509,7 +509,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
|
||||
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
|
||||
mscclpp::SemaphoreId buildAndAddSemaphore(mscclpp::Communicator& communicator,
|
||||
std::shared_ptr<mscclpp::Connection> connection) {
|
||||
const mscclpp::Connection& connection) {
|
||||
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator, connection));
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
@@ -554,19 +554,19 @@ mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyT
|
||||
continue;
|
||||
}
|
||||
int index = (r < rank_) ? r : r - 1;
|
||||
semaphores_[index]->connection()->write(remoteMemories_[index], offset, localMemory_, offset, sendBytes_);
|
||||
semaphores_[index]->connection().write(remoteMemories_[index], offset, localMemory_, offset, sendBytes_);
|
||||
semaphores_[index]->signal();
|
||||
}
|
||||
bool flushIpc = false;
|
||||
for (auto& semaphore : semaphores_) {
|
||||
auto conn = semaphore->connection();
|
||||
if (conn->transport() == mscclpp::Transport::CudaIpc && !flushIpc) {
|
||||
auto& conn = semaphore->connection();
|
||||
if (conn.transport() == mscclpp::Transport::CudaIpc && !flushIpc) {
|
||||
// since all the cudaIpc channels are using the same cuda stream, we only need to flush one of them
|
||||
conn->flush();
|
||||
conn.flush();
|
||||
flushIpc = true;
|
||||
}
|
||||
if (mscclpp::AllIBTransports.has(conn->transport())) {
|
||||
conn->flush();
|
||||
if (mscclpp::AllIBTransports.has(conn.transport())) {
|
||||
conn.flush();
|
||||
}
|
||||
}
|
||||
return mscclpp::ProxyHandlerResult::Continue;
|
||||
@@ -758,7 +758,7 @@ void AllGatherTestEngine::setupConnections() {
|
||||
} else {
|
||||
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
|
||||
setupMeshConnections(devPortChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
|
||||
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
|
||||
[&](std::vector<mscclpp::Connection> conns,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteMemories,
|
||||
const mscclpp::RegisteredMemory& localMemory) {
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
|
||||
@@ -364,14 +364,14 @@ std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService()
|
||||
}
|
||||
|
||||
void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localRegMemory,
|
||||
std::vector<mscclpp::Connection>& connections, mscclpp::RegisteredMemory& localRegMemory,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
|
||||
const int worldSize = args_.totalRanks;
|
||||
const int rank = args_.rank;
|
||||
const int nRanksPerNode = args_.nRanksPerNode;
|
||||
const int thisNode = rank / nRanksPerNode;
|
||||
const mscclpp::Transport ibTransport = IBs[args_.gpuNum];
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
|
||||
auto rankToNode = [&](int rank) { return rank / nRanksPerNode; };
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
@@ -393,7 +393,7 @@ void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
remoteRegMemories.push_back(remoteMemory);
|
||||
}
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const std::shared_future<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
|
||||
[](const std::shared_future<mscclpp::Connection>& future) { return future.get(); });
|
||||
}
|
||||
|
||||
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
|
||||
@@ -409,7 +409,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Port
|
||||
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory = (outputBuff) ? outputBufRegMem : inputBufRegMem;
|
||||
|
||||
@@ -441,7 +441,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory =
|
||||
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
|
||||
@@ -452,7 +452,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
|
||||
std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>> memorySemaphores;
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
for (size_t i = 0; i < nChannelPerConnection; ++i) {
|
||||
memorySemaphores[cid].emplace_back(
|
||||
std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[cid]));
|
||||
@@ -462,7 +462,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
|
||||
for (size_t i = 0; i < nChannelPerConnection; ++i) {
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
memoryChannels.emplace_back(memorySemaphores[cid][i], remoteRegMemories[cid].get(),
|
||||
(outputBuff && semantic == ChannelSemantic::GET) ? outputBufRegMem : inputBufRegMem,
|
||||
outputBuff);
|
||||
@@ -492,7 +492,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory =
|
||||
(getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem);
|
||||
@@ -513,7 +513,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
|
||||
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
memorySemaphores.emplace(cid, std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[cid]));
|
||||
} else {
|
||||
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
|
||||
@@ -521,7 +521,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
}
|
||||
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
|
||||
memoryChannels.emplace_back(memorySemaphores[cid],
|
||||
(outputBuff) ? remoteRegMemoriesOutput[cid].get() : remoteRegMemories[cid].get(),
|
||||
inputBufRegMem, (outputBuff) ? outputBufRegMem.data() : nullptr);
|
||||
|
||||
@@ -102,15 +102,15 @@ class BaseTestEngine {
|
||||
|
||||
double benchTime();
|
||||
|
||||
void setupMeshConnectionsInternal(std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
void setupMeshConnectionsInternal(std::vector<mscclpp::Connection>& connections,
|
||||
mscclpp::RegisteredMemory& localMemory,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories,
|
||||
bool addConnections = true);
|
||||
|
||||
protected:
|
||||
using SetupChannelFunc = std::function<void(std::vector<std::shared_ptr<mscclpp::Connection>>,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>&,
|
||||
const mscclpp::RegisteredMemory&)>;
|
||||
using SetupChannelFunc =
|
||||
std::function<void(std::vector<mscclpp::Connection>, std::vector<std::shared_future<mscclpp::RegisteredMemory>>&,
|
||||
const mscclpp::RegisteredMemory&)>;
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
void setupMeshConnections(std::vector<DeviceHandle<mscclpp::PortChannel>>& portChannels, void* inputBuff,
|
||||
|
||||
Reference in New Issue
Block a user