connect() APIs changed to return an instance instead of a shared_ptr (#680)

The key purpose is to manage all mscclpp objects' memory internally by
hiding shared pointers from the user-facing APIs.
* `Connection` class is now a wrapper of `BaseConnection` class that is
equivalent to the previous `Connection` class
* `connect()` methods now return `Connection` instead of
`std::shared_ptr<Connection>`
* Removed `connectOnSetup()` method
This commit is contained in:
Changho Hwang
2025-11-15 11:40:40 -08:00
committed by GitHub
parent 7eb3ff701a
commit 1bf4e8c90e
31 changed files with 252 additions and 213 deletions

View File

@@ -216,7 +216,7 @@ class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
private:
bool disableChannelCache_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores_;
const int nChannelsPerConnection_ = 35;
@@ -235,7 +235,7 @@ class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
mscclpp::Algorithm build() override;
private:
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
void initialize(std::shared_ptr<mscclpp::Communicator> comm,
std::unordered_map<std::string, std::shared_ptr<void>>& extras);

View File

@@ -1137,7 +1137,7 @@ class AllreducePacket : public mscclpp::AlgorithmBuilder {
size_t scratchBufferSize_;
std::shared_ptr<char> scratchBuffer_;
const int nSegmentsForScratchBuffer_ = 2;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
std::shared_ptr<uint32_t> deviceFlag7_;
std::shared_ptr<uint32_t> deviceFlag28_;
@@ -1164,7 +1164,7 @@ class AllreduceNvls : public mscclpp::AlgorithmBuilder {
uint32_t nSwitchChannels_;
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> memoryChannelsDeviceHandle_;
std::vector<mscclpp::BaseMemoryChannel> baseChannels_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
};
class AllreduceNvlsWithCopy : public mscclpp::AlgorithmBuilder {
@@ -1188,7 +1188,7 @@ class AllreduceNvlsWithCopy : public mscclpp::AlgorithmBuilder {
uint32_t nSwitchChannels_;
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> memoryChannelsDeviceHandle_;
std::vector<mscclpp::BaseMemoryChannel> baseChannels_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
};
class Allreduce8 : public mscclpp::AlgorithmBuilder {
@@ -1209,7 +1209,7 @@ class Allreduce8 : public mscclpp::AlgorithmBuilder {
size_t scratchBufferSize_;
std::shared_ptr<mscclpp::Communicator> comm_;
int nChannelsPerConnection_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
std::shared_ptr<char> scratchBuffer_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> outputSemaphores_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> inputScratchSemaphores_;

View File

@@ -171,7 +171,7 @@ class BroadcastAlgo6 : public mscclpp::AlgorithmBuilder {
void* output, size_t, ncclDataType_t);
mscclpp::AlgorithmCtxKey generateBroadcastContextKey(const void*, void*, size_t, ncclDataType_t);
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
size_t scratchMemSize_;
std::shared_ptr<char> scratchBuffer_;
};

View File

@@ -20,7 +20,7 @@ std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<msccl
}
std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
const std::vector<mscclpp::Connection>& connections,
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
const std::vector<mscclpp::RegisteredMemory>& remoteMemories, mscclpp::RegisteredMemory localMemory,
int nChannelsPerConnection) {
@@ -28,7 +28,7 @@ std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
size_t nConnections = connections.size();
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
for (size_t cid = 0; cid < nConnections; ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
channels.emplace_back(memorySemaphores[idx * nConnections + cid], remoteMemories[cid], localMemory, nullptr);
}
}
@@ -36,25 +36,25 @@ std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
return channels;
}
std::vector<std::shared_ptr<mscclpp::Connection>> setupConnections(std::shared_ptr<mscclpp::Communicator> comm) {
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Communicator> comm) {
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
if (i == comm->bootstrap()->getRank()) continue;
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::Connection> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const auto& future) { return future.get(); });
return connections;
}
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
int nChannelsPerConnection) {
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
memorySemaphores.emplace_back(
std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*(comm), connections[cid]));
}
@@ -117,14 +117,14 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SwitchChannel>> setupNvlsChannelD
}
std::vector<mscclpp::BaseMemoryChannel> setupBaseMemoryChannels(
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
const std::vector<mscclpp::Connection>& connections,
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
int nChannelsPerConnection) {
std::vector<mscclpp::BaseMemoryChannel> channels;
size_t nConnections = connections.size();
for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
for (size_t cid = 0; cid < nConnections; ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
channels.emplace_back(memorySemaphores[idx * nConnections + cid]);
}
}

View File

@@ -33,15 +33,15 @@ std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<msccl
mscclpp::RegisteredMemory localMemory);
std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
const std::vector<mscclpp::Connection>& connections,
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
const std::vector<mscclpp::RegisteredMemory>& remoteMemories, mscclpp::RegisteredMemory localMemory,
int nChannelsPerConnection);
std::vector<std::shared_ptr<mscclpp::Connection>> setupConnections(std::shared_ptr<mscclpp::Communicator> comm);
std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Communicator> comm);
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
int nChannelsPerConnection);
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
@@ -57,7 +57,7 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SwitchChannel>> setupNvlsChannelD
const std::vector<mscclpp::SwitchChannel>& nvlsChannels);
std::vector<mscclpp::BaseMemoryChannel> setupBaseMemoryChannels(
const std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
const std::vector<mscclpp::Connection>& connections,
const std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores,
int nChannelsPerConnection);

View File

@@ -83,8 +83,8 @@ The connection is created by calling `connect` on the context object:
```cpp
// From gpu_ping_pong.cu, lines 76 and 82
std::shared_ptr<mscclpp::Connection> conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
mscclpp::Connection conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
```
The `localEndpoint` and `remoteEndpoint` parameters specify which endpoints are used for the connection. A connection is asymmetric by nature, meaning that we need to create one connection for each endpoint. In this case, `conn0` is created for `ep0` to communicate with `ep1`, and `conn1` is created for `ep1` to communicate with `ep0`.
@@ -101,7 +101,7 @@ sendToProcessB(serializedEp0); // send serializedEp0 to Process B using any IPC
mscclpp::Endpoint ep1 = ctx->createEndpoint({transport, {mscclpp::DeviceType::GPU, 1}});
std::vector<char> serializedEp0 = recvFromProcessA(); // receive serializedEp0 from Process A
mscclpp::Endpoint ep0 = mscclpp::Endpoint::deserialize(serializedEp0);
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
```
## SemaphoreStub and Semaphore

View File

@@ -107,18 +107,18 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
}
private:
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<mscclpp::Connection> conns_;
std::shared_ptr<mscclpp::ProxyService> proxyService_;
int worldSize_;
void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
worldSize_ = comm->bootstrap()->getNranks();
for (int i = 0; i < worldSize_; i++) {
if (i == comm->bootstrap()->getRank()) continue;
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::Connection> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const auto& future) { return future.get(); });
this->conns_ = std::move(connections);

View File

@@ -73,13 +73,13 @@ int main() {
log("GPU 0: Creating a connection and a semaphore stub ...");
MSCCLPP_CUDATHROW(cudaSetDevice(0));
std::shared_ptr<mscclpp::Connection> conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
mscclpp::Connection conn0 = ctx->connect(/*localEndpoint*/ ep0, /*remoteEndpoint*/ ep1);
mscclpp::SemaphoreStub semaStub0(conn0);
log("GPU 1: Creating a connection and a semaphore stub ...");
MSCCLPP_CUDATHROW(cudaSetDevice(1));
std::shared_ptr<mscclpp::Connection> conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
mscclpp::Connection conn1 = ctx->connect(/*localEndpoint*/ ep1, /*remoteEndpoint*/ ep0);
mscclpp::SemaphoreStub semaStub1(conn1);
log("GPU 0: Creating a semaphore and a memory channel ...");

View File

@@ -425,6 +425,7 @@ struct EndpointConfig {
class Context;
class Connection;
class BaseConnection;
class RegisteredMemory;
class SemaphoreStub;
class Semaphore;
@@ -474,7 +475,7 @@ class Endpoint {
std::shared_ptr<Impl> pimpl_;
friend class Context;
friend class Connection;
friend class BaseConnection;
};
/// Context for communication. This provides a low-level interface for forming connections in use-cases
@@ -521,8 +522,8 @@ class Context : public std::enable_shared_from_this<Context> {
///
/// @param localEndpoint The local endpoint.
/// @param remoteEndpoint The remote endpoint.
/// @return A shared pointer to the connection.
std::shared_ptr<Connection> connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
/// @return A connection object.
Connection connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
private:
Context();
@@ -531,7 +532,7 @@ class Context : public std::enable_shared_from_this<Context> {
std::unique_ptr<Impl> pimpl_;
friend class Endpoint;
friend class Connection;
friend class BaseConnection;
friend class RegisteredMemory;
friend class SemaphoreStub;
};
@@ -578,7 +579,7 @@ class RegisteredMemory {
std::shared_ptr<Impl> pimpl_;
friend class Context;
friend class Connection;
friend class BaseConnection;
friend class SemaphoreStub;
friend class Semaphore;
};
@@ -587,12 +588,7 @@ class RegisteredMemory {
class Connection {
public:
/// Constructor.
/// @param context The context associated with the connection.
/// @param localEndpoint The local endpoint of the connection.
Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint);
/// Destructor.
virtual ~Connection() = default;
Connection() = default;
/// Write data from a source RegisteredMemory to a destination RegisteredMemory.
///
@@ -601,8 +597,7 @@ class Connection {
/// @param src The source RegisteredMemory.
/// @param srcOffset The offset in bytes from the start of the source RegisteredMemory.
/// @param size The number of bytes to write.
virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) = 0;
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size);
/// Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
///
@@ -610,19 +605,19 @@ class Connection {
/// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory.
/// @param src A pointer to the value to update.
/// @param newValue The new value to write.
virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue);
/// Flush any pending writes to the remote process.
/// @param timeoutUsec Timeout in microseconds. Default: -1 (no timeout)
virtual void flush(int64_t timeoutUsec = -1) = 0;
void flush(int64_t timeoutUsec = -1);
/// Get the transport used by the local process.
/// @return The transport used by the local process.
virtual Transport transport() const = 0;
Transport transport() const;
/// Get the transport used by the remote process.
/// @return The transport used by the remote process.
virtual Transport remoteTransport() const = 0;
Transport remoteTransport() const;
/// Get the context associated with this connection.
/// @return A shared pointer to the context associated with this connection.
@@ -636,22 +631,23 @@ class Connection {
/// @return The maximum number of write requests that can be queued.
int getMaxWriteQueueSize() const;
protected:
static const Endpoint::Impl& getImpl(const Endpoint& endpoint);
static const RegisteredMemory::Impl& getImpl(const RegisteredMemory& memory);
static Context::Impl& getImpl(Context& context);
private:
Connection(std::shared_ptr<BaseConnection> impl);
std::shared_ptr<BaseConnection> impl_;
std::shared_ptr<Context> context_;
Endpoint localEndpoint_;
int maxWriteQueueSize_;
friend class Context;
friend class Communicator;
friend class SemaphoreStub;
friend class Semaphore;
friend class ProxyService;
};
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
class SemaphoreStub {
public:
/// Constructor.
/// @param connection A shared pointer to the connection associated with this semaphore.
SemaphoreStub(std::shared_ptr<Connection> connection);
/// @param connection The connection associated with this semaphore.
SemaphoreStub(const Connection& connection);
/// Get the memory associated with this semaphore.
/// @return A reference to the registered memory for this semaphore.
@@ -686,8 +682,8 @@ class Semaphore {
Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub);
/// Get the connection associated with this semaphore.
/// @return A shared pointer to the connection.
std::shared_ptr<Connection> connection() const;
/// @return The connection.
Connection& connection();
/// Get the local memory associated with this semaphore.
/// @return A reference to the local registered memory.
@@ -873,34 +869,23 @@ class Communicator {
/// @param localEndpoint The local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
/// @return A future of the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
std::shared_future<Connection> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
/// Connect to a remote rank. Wrapper of `connect(localEndpoint, remoteRank, tag)`.
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
int tag = 0);
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>
connect(int remoteRank, int tag, EndpointConfig localConfig);
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
std::shared_ptr<Connection>>
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
return connect(localConfig, remoteRank, tag);
}
/// @return A future of the connection.
std::shared_future<Connection> connect(const EndpointConfig& localConfig, int remoteRank, int tag = 0);
/// Build a semaphore for cross-process synchronization.
/// @param connection The connection associated with this semaphore.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the operation.
/// @return A future of the built semaphore.
std::shared_future<Semaphore> buildSemaphore(std::shared_ptr<Connection> connection, int remoteRank, int tag = 0);
std::shared_future<Semaphore> buildSemaphore(const Connection& connection, int remoteRank, int tag = 0);
/// Get the remote rank a connection is connected to.
///

View File

@@ -33,7 +33,7 @@ class ProxyService : public BaseProxyService {
/// Build and add a semaphore to the proxy service.
/// @param connection The connection associated with the semaphore.
/// @return The ID of the semaphore.
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
SemaphoreId buildAndAddSemaphore(Communicator& communicator, const Connection& connection);
/// Add a semaphore to the proxy service.
/// @param semaphore The semaphore to be added
@@ -83,7 +83,7 @@ class ProxyService : public BaseProxyService {
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests_;
std::unordered_map<std::shared_ptr<BaseConnection>, int> inflightRequests_;
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw);
};

View File

@@ -27,11 +27,11 @@ class Host2DeviceSemaphore {
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
Host2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
Host2DeviceSemaphore(Communicator& communicator, const Connection& connection);
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection() const;
Connection& connection();
/// Signal the device.
void signal();
@@ -59,11 +59,11 @@ class Host2HostSemaphore {
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore. Transport::CudaIpc is not allowed for
/// Host2HostSemaphore.
Host2HostSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
Host2HostSemaphore(Communicator& communicator, const Connection& connection);
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection() const;
Connection& connection();
/// Signal the remote host.
void signal();
@@ -92,11 +92,11 @@ class MemoryDevice2DeviceSemaphore {
/// Constructor.
/// @param communicator The communicator.
/// @param connection The connection associated with this semaphore.
MemoryDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr<Connection> connection);
MemoryDevice2DeviceSemaphore(Communicator& communicator, const Connection& connection);
/// Returns the connection.
/// @return The connection associated with this semaphore.
std::shared_ptr<Connection> connection() const;
Connection& connection();
/// Device-side handle for MemoryDevice2DeviceSemaphore.
using DeviceHandle = MemoryDevice2DeviceSemaphoreDeviceHandle;

View File

@@ -202,7 +202,7 @@ void register_core(nb::module_& m) {
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
.def(nb::init<const Connection&>(), nb::arg("connection"))
.def("memory", &SemaphoreStub::memory)
.def("serialize", &SemaphoreStub::serialize)
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
@@ -215,7 +215,7 @@ void register_core(nb::module_& m) {
.def("remote_memory", &Semaphore::remoteMemory);
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
def_shared_future<Connection>(m, "Connection");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -231,8 +231,8 @@ void register_core(nb::module_& m) {
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int,
int)>(&Communicator::connect),
static_cast<std::shared_future<Connection> (Communicator::*)(const EndpointConfig&, int, int)>(
&Communicator::connect),
nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
.def(
"connect_on_setup",

View File

@@ -12,7 +12,7 @@ using namespace mscclpp;
void register_semaphore(nb::module_& m) {
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2DeviceSemaphore::connection)
.def("signal", &Host2DeviceSemaphore::signal)
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
@@ -27,7 +27,7 @@ void register_semaphore(nb::module_& m) {
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &Host2HostSemaphore::connection)
.def("signal", &Host2HostSemaphore::signal)
.def("poll", &Host2HostSemaphore::poll)
@@ -36,7 +36,7 @@ void register_semaphore(nb::module_& m) {
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);

View File

@@ -21,13 +21,13 @@ class MyProxyService {
private:
int deviceNumaNode_;
int my_rank_, nranks_, dataSize_;
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
std::vector<mscclpp::Connection> connections_;
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
mscclpp::Proxy proxy_;
public:
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<std::shared_ptr<mscclpp::Connection>> conns,
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
: my_rank_(my_rank),
@@ -46,10 +46,10 @@ class MyProxyService {
int dataSizePerRank = dataSize_ / nranks_;
for (int r = 1; r < nranks_; ++r) {
int nghr = (my_rank_ + r) % nranks_;
connections_[nghr]->write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
semaphores_[nghr]->signal();
connections_[nghr]->flush();
connections_[nghr].flush();
}
return mscclpp::ProxyHandlerResult::Continue;
}
@@ -63,7 +63,7 @@ class MyProxyService {
void init_mscclpp_proxy_test_module(nb::module_ &m) {
nb::class_<MyProxyService>(m, "MyProxyService")
.def(nb::init<int, int, int, std::vector<std::shared_ptr<mscclpp::Connection>>,
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),

View File

@@ -99,16 +99,16 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
int remoteRank, int tag) {
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const Endpoint& localEndpoint, int remoteRank,
int tag) {
if (remoteRank == bootstrap()->getRank()) {
// Connection to self
auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
auto connection = context()->connect(localEndpoint, remoteEndpoint);
std::promise<std::shared_ptr<Connection>> promise;
std::promise<Connection> promise;
promise.set_value(connection);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
return std::shared_future<Connection>(promise.get_future());
}
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
@@ -123,27 +123,22 @@ MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::co
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
pimpl_->connectionInfos_[connection.impl_.get()] = {remoteRank, tag};
return connection;
});
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
auto shared_future = std::shared_future<Connection>(std::move(future));
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<Connection>>(shared_future));
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
int remoteRank, int tag) {
MSCCLPP_API_CPP std::shared_future<Connection> Communicator::connect(const EndpointConfig& localConfig, int remoteRank,
int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);
return connect(localEndpoint, remoteRank, tag);
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
EndpointConfig localConfig) {
return connect(localConfig, remoteRank, tag);
}
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::shared_ptr<Connection> connection,
int remoteRank, int tag) {
MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(const Connection& connection, int remoteRank,
int tag) {
SemaphoreStub localStub(connection);
bootstrap()->send(localStub.serialize(), remoteRank, tag);
@@ -165,11 +160,11 @@ MSCCLPP_API_CPP std::shared_future<Semaphore> Communicator::buildSemaphore(std::
}
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
return pimpl_->connectionInfos_.at(&connection).remoteRank;
return pimpl_->connectionInfos_.at(connection.impl_.get()).remoteRank;
}
MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
return pimpl_->connectionInfos_.at(&connection).tag;
return pimpl_->connectionInfos_.at(connection.impl_.get()).tag;
}
} // namespace mscclpp

View File

@@ -32,28 +32,54 @@ static bool isSameProcess(const Endpoint& a, const Endpoint& b) {
return a.hostHash() == b.hostHash() && a.pidHash() == b.pidHash();
}
// Connection
// BaseConnection
const Endpoint::Impl& Connection::getImpl(const Endpoint& endpoint) { return *(endpoint.pimpl_); }
const Endpoint::Impl& BaseConnection::getImpl(const Endpoint& endpoint) { return *(endpoint.pimpl_); }
const RegisteredMemory::Impl& Connection::getImpl(const RegisteredMemory& memory) { return *(memory.pimpl_); }
const RegisteredMemory::Impl& BaseConnection::getImpl(const RegisteredMemory& memory) { return *(memory.pimpl_); }
Context::Impl& Connection::getImpl(Context& context) { return *(context.pimpl_); }
Context::Impl& BaseConnection::getImpl(Context& context) { return *(context.pimpl_); }
MSCCLPP_API_CPP Connection::Connection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
MSCCLPP_API_CPP BaseConnection::BaseConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint)
: context_(context), localEndpoint_(localEndpoint), maxWriteQueueSize_(localEndpoint.maxWriteQueueSize()) {}
MSCCLPP_API_CPP std::shared_ptr<Context> Connection::context() const { return context_; }
MSCCLPP_API_CPP std::shared_ptr<Context> BaseConnection::context() const { return context_; }
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return localEndpoint_.device(); }
MSCCLPP_API_CPP const Device& BaseConnection::localDevice() const { return localEndpoint_.device(); }
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
MSCCLPP_API_CPP int BaseConnection::getMaxWriteQueueSize() const { return maxWriteQueueSize_; }
// Connection wrapper
//
// Connection is the public, copyable handle returned by Context::connect(). It
// forwards every operation to a shared BaseConnection implementation (CudaIpc,
// IB, or Ethernet), so copies of a Connection refer to the same underlying
// transport object and its lifetime is managed by the shared_ptr refcount.

// Wraps an existing implementation object; ownership is shared via impl_.
Connection::Connection(std::shared_ptr<BaseConnection> impl) : impl_(impl) {}

// Forwards a remote write of `size` bytes from `src` (+ srcOffset) to
// `dst` (+ dstOffset) to the transport-specific implementation.
MSCCLPP_API_CPP void Connection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src,
                                       uint64_t srcOffset, uint64_t size) {
  impl_->write(dst, dstOffset, src, srcOffset, size);
}

// Forwards updateAndSync: the implementation updates *src to `newValue` and
// propagates it to `dst` (+ dstOffset). Exact ordering/atomicity guarantees
// are transport-specific — see the concrete BaseConnection subclasses.
MSCCLPP_API_CPP void Connection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src,
                                               uint64_t newValue) {
  impl_->updateAndSync(dst, dstOffset, src, newValue);
}

// Flushes outstanding operations on the underlying connection. `timeoutUsec`
// is passed through unchanged (BaseConnection::flush defaults it to -1).
MSCCLPP_API_CPP void Connection::flush(int64_t timeoutUsec) { impl_->flush(timeoutUsec); }

// Local-side transport of the underlying connection.
MSCCLPP_API_CPP Transport Connection::transport() const { return impl_->transport(); }

// Remote-side transport of the underlying connection.
MSCCLPP_API_CPP Transport Connection::remoteTransport() const { return impl_->remoteTransport(); }

// Context this connection was created from.
MSCCLPP_API_CPP std::shared_ptr<Context> Connection::context() const { return impl_->context(); }

// Device of the local endpoint backing this connection.
MSCCLPP_API_CPP const Device& Connection::localDevice() const { return impl_->localDevice(); }

// Maximum write-queue size configured on the local endpoint.
MSCCLPP_API_CPP int Connection::getMaxWriteQueueSize() const { return impl_->getMaxWriteQueueSize(); }
// CudaIpcConnection
CudaIpcConnection::CudaIpcConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
const Endpoint& remoteEndpoint)
: Connection(context, localEndpoint) {
: BaseConnection(context, localEndpoint) {
if (localEndpoint.transport() != Transport::CudaIpc || remoteEndpoint.transport() != Transport::CudaIpc) {
THROW(CONN, Error, ErrorCode::InternalError, "CudaIpc transport is required for CudaIpcConnection");
}
@@ -163,7 +189,7 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
const Endpoint& remoteEndpoint)
: Connection(context, localEndpoint),
: BaseConnection(context, localEndpoint),
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
@@ -276,7 +302,7 @@ void IBConnection::flush(int64_t timeoutUsec) {
EthernetConnection::EthernetConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint,
const Endpoint& remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize)
: Connection(context, localEndpoint),
: BaseConnection(context, localEndpoint),
abortFlag_(0),
sendBufferSize_(sendBufferSize),
recvBufferSize_(recvBufferSize) {

View File

@@ -76,8 +76,7 @@ MSCCLPP_API_CPP Endpoint Context::createEndpoint(EndpointConfig config) {
return Endpoint(std::make_shared<Endpoint::Impl>(config, *pimpl_));
}
MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &localEndpoint,
const Endpoint &remoteEndpoint) {
MSCCLPP_API_CPP Connection Context::connect(const Endpoint &localEndpoint, const Endpoint &remoteEndpoint) {
if (localEndpoint.device().type == DeviceType::GPU && localEndpoint.device().id < 0) {
throw Error("No GPU device ID provided for local endpoint", ErrorCode::InvalidUsage);
}
@@ -93,7 +92,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
throw Error(ss.str(), ErrorCode::InvalidUsage);
}
std::shared_ptr<Connection> conn;
std::shared_ptr<BaseConnection> conn;
if (localTransport == Transport::CudaIpc) {
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (AllIBTransports.has(localTransport)) {
@@ -103,7 +102,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
} else {
throw Error("Unsupported transport", ErrorCode::InternalError);
}
return conn;
return Connection(conn);
}
} // namespace mscclpp

View File

@@ -112,7 +112,7 @@ namespace mscclpp {
struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, std::shared_ptr<Connection>> connections;
std::unordered_map<int, Connection> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
MemoryId localMemoryIdBegin = MemoryId(0);
@@ -270,7 +270,7 @@ struct Executor::Impl {
};
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
@@ -339,10 +339,10 @@ struct Executor::Impl {
auto connection = context.connections.at(peer);
if (info.channelType == ChannelType::MEMORY) {
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
} else if (info.channelType == ChannelType::PORT) {
futureProxySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
futureProxySemaphores.push_back(this->comm->buildSemaphore(connection, this->comm->remoteRankOf(connection),
this->comm->tagOf(connection)));
}
}
}

View File

@@ -59,7 +59,7 @@ struct ConnectionInfo {
struct Communicator::Impl {
std::shared_ptr<Bootstrap> bootstrap_;
std::shared_ptr<Context> context_;
std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
std::unordered_map<const BaseConnection*, ConnectionInfo> connectionInfos_;
// Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
// If the RecvItem gets ready, it will be removed at the next call to getLastRecvItem.

View File

@@ -15,7 +15,46 @@
namespace mscclpp {
class CudaIpcConnection : public Connection {
/// Internal base class for connection implementations between two processes.
/// Concrete subclasses (CudaIpcConnection, IBConnection, EthernetConnection)
/// implement the transport-specific data path; user code only ever holds the
/// public Connection wrapper, which shares ownership of a BaseConnection.
class BaseConnection {
 public:
  /// @param context Context the connection belongs to (kept alive via shared_ptr).
  /// @param localEndpoint Local endpoint; its device and max write-queue size are cached.
  BaseConnection(std::shared_ptr<Context> context, const Endpoint& localEndpoint);

  /// Virtual destructor: instances are deleted through BaseConnection pointers.
  virtual ~BaseConnection() = default;

  /// Write `size` bytes from `src` (+ srcOffset) to remote `dst` (+ dstOffset).
  virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                     uint64_t size) = 0;

  /// Update *src to `newValue` and propagate it to `dst` (+ dstOffset);
  /// semantics are transport-specific.
  virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0;

  /// Flush outstanding operations. The default timeout of -1 is forwarded by
  /// Connection::flush(); its interpretation is up to the implementation.
  virtual void flush(int64_t timeoutUsec = -1) = 0;

  /// Transport used by the local endpoint.
  virtual Transport transport() const = 0;

  /// Transport used by the remote endpoint.
  virtual Transport remoteTransport() const = 0;

  /// Context passed at construction.
  std::shared_ptr<Context> context() const;

  /// Device of the local endpoint passed at construction.
  const Device& localDevice() const;

  /// Max write-queue size cached from the local endpoint at construction.
  int getMaxWriteQueueSize() const;

 protected:
  // Context and the concrete connection classes need access to the pimpl
  // accessors below and to the protected members.
  friend class Context;
  friend class CudaIpcConnection;
  friend class IBConnection;
  friend class EthernetConnection;

  // Pimpl accessors: expose the private Impl of these public-API classes to
  // connection implementations without widening the public headers.
  static const Endpoint::Impl& getImpl(const Endpoint& endpoint);
  static const RegisteredMemory::Impl& getImpl(const RegisteredMemory& memory);
  static Context::Impl& getImpl(Context& context);

  std::shared_ptr<Context> context_;  // owning reference to the parent context
  Endpoint localEndpoint_;            // local endpoint this connection was built from
  int maxWriteQueueSize_;             // cached from localEndpoint.maxWriteQueueSize()
};
class CudaIpcConnection : public BaseConnection {
private:
std::shared_ptr<CudaIpcStream> stream_;
@@ -33,7 +72,7 @@ class CudaIpcConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};
class IBConnection : public Connection {
class IBConnection : public BaseConnection {
private:
Transport transport_;
Transport remoteTransport_;
@@ -56,7 +95,7 @@ class IBConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};
class EthernetConnection : public Connection {
class EthernetConnection : public BaseConnection {
private:
std::unique_ptr<Socket> sendSocket_;
std::unique_ptr<Socket> recvSocket_;

View File

@@ -42,7 +42,7 @@ MSCCLPP_API_CPP ProxyService::ProxyService(int fifoSize) {
}
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection) {
const Connection& connection) {
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection));
return semaphores_.size() - 1;
}
@@ -89,13 +89,14 @@ MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger.fields.semaphoreId];
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
auto& numRequests = inflightRequests_[semaphore->connection()];
auto& conn = semaphore->connection();
int maxWriteQueueSize = conn.getMaxWriteQueueSize();
auto& numRequests = inflightRequests_[conn.impl_];
if (trigger.fields.type & TriggerData) {
RegisteredMemory& dst = memories_[trigger.fields.dstMemoryId];
RegisteredMemory& src = memories_[trigger.fields.srcMemoryId];
semaphore->connection()->write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
conn.write(dst, trigger.fields.dstOffset, src, trigger.fields.srcOffset, trigger.fields.size);
numRequests++;
}
@@ -106,7 +107,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger trigger) {
if (((trigger.fields.type & TriggerSync) && numRequests > 0) ||
(maxWriteQueueSize != -1 && numRequests > maxWriteQueueSize)) {
semaphore->connection()->flush();
conn.flush();
numRequests = 0;
}

View File

@@ -6,6 +6,7 @@
#include "api.h"
#include "atomic.hpp"
#include "connection.hpp"
#include "context.hpp"
#include "debug.h"
#include "registered_memory.hpp"
@@ -14,7 +15,7 @@
namespace mscclpp {
struct SemaphoreStub::Impl {
Impl(std::shared_ptr<Connection> connection);
Impl(const Connection& connection);
Impl(const RegisteredMemory& idMemory, const Device& device);
@@ -22,7 +23,7 @@ struct SemaphoreStub::Impl {
std::shared_ptr<uint64_t> gpuCallocToken(std::shared_ptr<Context> context);
std::shared_ptr<Connection> connection_;
Connection connection_;
std::shared_ptr<uint64_t> token_;
RegisteredMemory idMemory_;
Device device_;
@@ -41,9 +42,9 @@ std::shared_ptr<uint64_t> SemaphoreStub::Impl::gpuCallocToken(std::shared_ptr<Co
#endif // !defined(MSCCLPP_DEVICE_HIP)
}
SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection) : connection_(connection) {
SemaphoreStub::Impl::Impl(const Connection& connection) : connection_(connection) {
// Allocate a semaphore ID on the local device
const Device& localDevice = connection_->localDevice();
const Device& localDevice = connection_.localDevice();
if (localDevice.type == DeviceType::CPU) {
token_ = std::make_shared<uint64_t>(0);
} else if (localDevice.type == DeviceType::GPU) {
@@ -51,12 +52,11 @@ SemaphoreStub::Impl::Impl(std::shared_ptr<Connection> connection) : connection_(
throw Error("Local GPU ID is not provided", ErrorCode::InvalidUsage);
}
MSCCLPP_CUDATHROW(cudaSetDevice(localDevice.id));
token_ = gpuCallocToken(connection_->context());
token_ = gpuCallocToken(connection_.context());
} else {
throw Error("Unsupported local device type", ErrorCode::InvalidUsage);
}
idMemory_ =
std::move(connection->context()->registerMemory(token_.get(), sizeof(uint64_t), connection_->transport()));
idMemory_ = std::move(connection_.context()->registerMemory(token_.get(), sizeof(uint64_t), connection_.transport()));
}
SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device)
@@ -64,7 +64,7 @@ SemaphoreStub::Impl::Impl(const RegisteredMemory& idMemory, const Device& device
SemaphoreStub::SemaphoreStub(std::shared_ptr<Impl> pimpl) : pimpl_(std::move(pimpl)) {}
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(std::shared_ptr<Connection> connection)
MSCCLPP_API_CPP SemaphoreStub::SemaphoreStub(const Connection& connection)
: pimpl_(std::make_shared<Impl>(connection)) {}
MSCCLPP_API_CPP std::vector<char> SemaphoreStub::serialize() const {
@@ -103,17 +103,15 @@ Semaphore::Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remote
}
}
MSCCLPP_API_CPP std::shared_ptr<Connection> Semaphore::connection() const {
return pimpl_->localStub_.pimpl_->connection_;
}
MSCCLPP_API_CPP Connection& Semaphore::connection() { return pimpl_->localStub_.pimpl_->connection_; }
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::localMemory() const { return pimpl_->localStub_.memory(); }
MSCCLPP_API_CPP const RegisteredMemory& Semaphore::remoteMemory() const { return pimpl_->remoteStubMemory_; }
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, std::shared_ptr<Connection> connection) {
static Semaphore buildSemaphoreFromConnection(Communicator& communicator, const Connection& connection) {
auto semaphoreFuture =
communicator.buildSemaphore(connection, communicator.remoteRankOf(*connection), communicator.tagOf(*connection));
communicator.buildSemaphore(connection, communicator.remoteRankOf(connection), communicator.tagOf(connection));
return semaphoreFuture.get();
}
@@ -121,19 +119,18 @@ MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(const Semaphore& sema
: semaphore_(semaphore),
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
outboundToken_(std::make_unique<uint64_t>()) {
if (connection()->localDevice().type != DeviceType::GPU) {
if (connection().localDevice().type != DeviceType::GPU) {
throw Error("Local endpoint device type of Host2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator, const Connection& connection)
: Host2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2DeviceSemaphore::connection() const { return semaphore_.connection(); }
MSCCLPP_API_CPP Connection& Host2DeviceSemaphore::connection() { return semaphore_.connection(); }
MSCCLPP_API_CPP void Host2DeviceSemaphore::signal() {
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
connection().updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
}
MSCCLPP_API_CPP Host2DeviceSemaphore::DeviceHandle Host2DeviceSemaphore::deviceHandle() const {
@@ -147,22 +144,21 @@ MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(const Semaphore& semaphor
: semaphore_(semaphore),
expectedInboundToken_(std::make_unique<uint64_t>()),
outboundToken_(std::make_unique<uint64_t>()) {
if (connection()->transport() == Transport::CudaIpc) {
if (connection().transport() == Transport::CudaIpc) {
throw Error("Host2HostSemaphore cannot be used with CudaIpc transport", ErrorCode::InvalidUsage);
}
if (connection()->localDevice().type != DeviceType::CPU) {
if (connection().localDevice().type != DeviceType::CPU) {
throw Error("Local endpoint device type of Host2HostSemaphore should be CPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicator, const Connection& connection)
: Host2HostSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> Host2HostSemaphore::connection() const { return semaphore_.connection(); }
MSCCLPP_API_CPP Connection& Host2HostSemaphore::connection() { return semaphore_.connection(); }
MSCCLPP_API_CPP void Host2HostSemaphore::signal() {
connection()->updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
connection().updateAndSync(semaphore_.remoteMemory(), 0, outboundToken_.get(), *outboundToken_ + 1);
}
MSCCLPP_API_CPP bool Host2HostSemaphore::poll() {
@@ -187,18 +183,16 @@ MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const
: semaphore_(semaphore),
expectedInboundToken_(detail::gpuCallocUnique<uint64_t>()),
outboundToken_(detail::gpuCallocUnique<uint64_t>()) {
if (connection()->localDevice().type != DeviceType::GPU) {
if (connection().localDevice().type != DeviceType::GPU) {
throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage);
}
}
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection)
const Connection& connection)
: MemoryDevice2DeviceSemaphore(buildSemaphoreFromConnection(communicator, connection)) {}
MSCCLPP_API_CPP std::shared_ptr<Connection> MemoryDevice2DeviceSemaphore::connection() const {
return semaphore_.connection();
}
MSCCLPP_API_CPP Connection& MemoryDevice2DeviceSemaphore::connection() { return semaphore_.connection(); }
MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSemaphore::deviceHandle() const {
MemoryDevice2DeviceSemaphore::DeviceHandle device;

View File

@@ -213,7 +213,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
std::vector<mscclpp::SemaphoreId> semaphoreIds;
std::vector<mscclpp::RegisteredMemory> localMemories;
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
std::vector<std::shared_future<mscclpp::Connection>> connections(world_size);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemories;
for (int r = 0; r < world_size; ++r) {

View File

@@ -85,7 +85,7 @@ class MyProxyService {
mscclpp::RegisteredMemory localMemory_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores1_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores2_;
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
std::vector<mscclpp::Connection> connections_;
mscclpp::Proxy proxy_;
public:
@@ -98,7 +98,7 @@ class MyProxyService {
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
std::vector<std::shared_future<mscclpp::Connection>> connectionsFuture(world_size);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
@@ -138,15 +138,15 @@ class MyProxyService {
int dataSizePerRank = dataSize_ / world_size;
for (int r = 1; r < world_size; ++r) {
int nghr = (rank + r) % world_size;
connections_[nghr]->write(remoteMemories_[nghr], rank * dataSizePerRank, localMemory_, rank * dataSizePerRank,
dataSizePerRank);
connections_[nghr].write(remoteMemories_[nghr], rank * dataSizePerRank, localMemory_, rank * dataSizePerRank,
dataSizePerRank);
if (triggerRaw.fst == 1)
deviceSemaphores1_[nghr]->signal();
else
deviceSemaphores2_[nghr]->signal();
if ((flusher % 64) == 0 && mscclpp::AllIBTransports.has(connections_[nghr]->transport())) {
if ((flusher % 64) == 0 && mscclpp::AllIBTransports.has(connections_[nghr].transport())) {
// if we are using IB transport, we need a flush every once in a while
connections_[nghr]->flush();
connections_[nghr].flush();
}
}
flusher++;

View File

@@ -44,8 +44,8 @@ void CommunicatorTestBase::TearDown() {
void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; }
void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) {
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> cpuConnectionFutures(numRanksToUse);
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(numRanksToUse);
std::vector<std::shared_future<mscclpp::Connection>> cpuConnectionFutures(numRanksToUse);
for (int i = 0; i < numRanksToUse; i++) {
if (i != gEnv->rank) {
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
@@ -158,9 +158,9 @@ void CommunicatorTest::writeToRemote(int dataCountPerRank) {
if (i != gEnv->rank) {
auto& conn = connections.at(i);
auto& peerMemory = remoteMemory[n].at(i);
conn->write(peerMemory, gEnv->rank * dataCountPerRank * sizeof(int), localMemory[n],
gEnv->rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
conn->flush();
conn.write(peerMemory, gEnv->rank * dataCountPerRank * sizeof(int), localMemory[n],
gEnv->rank * dataCountPerRank * sizeof(int), dataCountPerRank * sizeof(int));
conn.flush();
}
}
}
@@ -262,7 +262,7 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
for (auto entry : cpuConnections) {
auto& conn = entry.second;
// Host2HostSemaphore cannot be used with CudaIpc transport
if (conn->transport() == mscclpp::Transport::CudaIpc) continue;
if (conn.transport() == mscclpp::Transport::CudaIpc) continue;
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
}
communicator->bootstrap()->barrier();
@@ -273,25 +273,25 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
writeToRemote(deviceBufferSize / sizeof(int) / gEnv->worldSize);
for (int i = 0; i < gEnv->worldSize; i++) {
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
semaphores[i]->signal();
}
}
for (int i = 0; i < gEnv->worldSize; i++) {
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
semaphores[i]->wait();
}
}
for (int i = 0; i < gEnv->worldSize; i++) {
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
semaphores[i]->signal();
}
}
for (int i = 0; i < gEnv->worldSize; i++) {
if (i != gEnv->rank && connections[i]->transport() != mscclpp::Transport::CudaIpc) {
if (i != gEnv->rank && connections[i].transport() != mscclpp::Transport::CudaIpc) {
semaphores[i]->wait();
}
}

View File

@@ -25,7 +25,7 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
const bool isInPlace = (outputBuff == nullptr);
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc;
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(worldSize);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
mscclpp::RegisteredMemory inputBufRegMem = communicator->registerMemory(inputBuff, inputBuffBytes, transport);

View File

@@ -108,8 +108,8 @@ class CommunicatorTestBase : public MultiProcessTest {
std::shared_ptr<mscclpp::Communicator> communicator;
mscclpp::Transport ibTransport;
std::vector<mscclpp::RegisteredMemory> registeredMemories;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> connections;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cpuConnections;
std::unordered_map<int, mscclpp::Connection> connections;
std::unordered_map<int, mscclpp::Connection> cpuConnections;
};
class CommunicatorTest : public CommunicatorTestBase {

View File

@@ -28,7 +28,7 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
if (useIb) transport |= ibTransport;
if (useEthernet) transport |= mscclpp::Transport::Ethernet;
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures(worldSize);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport);

View File

@@ -509,7 +509,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
mscclpp::SemaphoreId buildAndAddSemaphore(mscclpp::Communicator& communicator,
std::shared_ptr<mscclpp::Connection> connection) {
const mscclpp::Connection& connection) {
semaphores_.push_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(communicator, connection));
return semaphores_.size() - 1;
}
@@ -554,19 +554,19 @@ mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyT
continue;
}
int index = (r < rank_) ? r : r - 1;
semaphores_[index]->connection()->write(remoteMemories_[index], offset, localMemory_, offset, sendBytes_);
semaphores_[index]->connection().write(remoteMemories_[index], offset, localMemory_, offset, sendBytes_);
semaphores_[index]->signal();
}
bool flushIpc = false;
for (auto& semaphore : semaphores_) {
auto conn = semaphore->connection();
if (conn->transport() == mscclpp::Transport::CudaIpc && !flushIpc) {
auto& conn = semaphore->connection();
if (conn.transport() == mscclpp::Transport::CudaIpc && !flushIpc) {
// since all the cudaIpc channels are using the same cuda stream, we only need to flush one of them
conn->flush();
conn.flush();
flushIpc = true;
}
if (mscclpp::AllIBTransports.has(conn->transport())) {
conn->flush();
if (mscclpp::AllIBTransports.has(conn.transport())) {
conn.flush();
}
}
return mscclpp::ProxyHandlerResult::Continue;
@@ -758,7 +758,7 @@ void AllGatherTestEngine::setupConnections() {
} else {
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
setupMeshConnections(devPortChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
[&](std::vector<mscclpp::Connection> conns,
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteMemories,
const mscclpp::RegisteredMemory& localMemory) {
std::vector<mscclpp::SemaphoreId> semaphoreIds;

View File

@@ -364,14 +364,14 @@ std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService()
}
void BaseTestEngine::setupMeshConnectionsInternal(
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localRegMemory,
std::vector<mscclpp::Connection>& connections, mscclpp::RegisteredMemory& localRegMemory,
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
const int worldSize = args_.totalRanks;
const int rank = args_.rank;
const int nRanksPerNode = args_.nRanksPerNode;
const int thisNode = rank / nRanksPerNode;
const mscclpp::Transport ibTransport = IBs[args_.gpuNum];
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
auto rankToNode = [&](int rank) { return rank / nRanksPerNode; };
for (int r = 0; r < worldSize; r++) {
@@ -393,7 +393,7 @@ void BaseTestEngine::setupMeshConnectionsInternal(
remoteRegMemories.push_back(remoteMemory);
}
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const std::shared_future<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
[](const std::shared_future<mscclpp::Connection>& future) { return future.get(); });
}
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
@@ -409,7 +409,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Port
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::Connection> connections;
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory = (outputBuff) ? outputBufRegMem : inputBufRegMem;
@@ -441,7 +441,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::Connection> connections;
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory =
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
@@ -452,7 +452,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>> memorySemaphores;
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
for (size_t i = 0; i < nChannelPerConnection; ++i) {
memorySemaphores[cid].emplace_back(
std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[cid]));
@@ -462,7 +462,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
for (size_t i = 0; i < nChannelPerConnection; ++i) {
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
memoryChannels.emplace_back(memorySemaphores[cid][i], remoteRegMemories[cid].get(),
(outputBuff && semantic == ChannelSemantic::GET) ? outputBufRegMem : inputBufRegMem,
outputBuff);
@@ -492,7 +492,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
outputBufRegMem = comm_->registerMemory(outputBuff, outputBuffBytes, allTransports);
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::Connection> connections;
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory =
(getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem);
@@ -513,7 +513,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
auto service = std::dynamic_pointer_cast<mscclpp::ProxyService>(chanService_);
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
memorySemaphores.emplace(cid, std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, connections[cid]));
} else {
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
@@ -521,7 +521,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
}
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
memoryChannels.emplace_back(memorySemaphores[cid],
(outputBuff) ? remoteRegMemoriesOutput[cid].get() : remoteRegMemories[cid].get(),
inputBufRegMem, (outputBuff) ? outputBufRegMem.data() : nullptr);

View File

@@ -102,15 +102,15 @@ class BaseTestEngine {
double benchTime();
void setupMeshConnectionsInternal(std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
void setupMeshConnectionsInternal(std::vector<mscclpp::Connection>& connections,
mscclpp::RegisteredMemory& localMemory,
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories,
bool addConnections = true);
protected:
using SetupChannelFunc = std::function<void(std::vector<std::shared_ptr<mscclpp::Connection>>,
std::vector<std::shared_future<mscclpp::RegisteredMemory>>&,
const mscclpp::RegisteredMemory&)>;
using SetupChannelFunc =
std::function<void(std::vector<mscclpp::Connection>, std::vector<std::shared_future<mscclpp::RegisteredMemory>>&,
const mscclpp::RegisteredMemory&)>;
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
void setupMeshConnections(std::vector<DeviceHandle<mscclpp::PortChannel>>& portChannels, void* inputBuff,