mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Update EndpointConfig interfaces (#651)
* Separate IB-specific options into a nested struct
* Enable `connect()` by an `Endpoint`, not only by `EndpointConfig`
* Other minor changes
This commit is contained in:
@@ -374,41 +374,53 @@ struct Device {
|
||||
int id;
|
||||
};
|
||||
|
||||
/// Used to configure an endpoint.
|
||||
/// Configuration for creating communication endpoints.
|
||||
struct EndpointConfig {
|
||||
static const int DefaultMaxCqSize = 1024;
|
||||
static const int DefaultMaxCqPollNum = 1;
|
||||
static const int DefaultMaxSendWr = 8192;
|
||||
static const int DefaultMaxWrPerSend = 64;
|
||||
/// InfiniBand-specific configuration options that control queue pair behavior and performance characteristics.
|
||||
/// These settings are only used when the transport is an InfiniBand type (IB0-IB7); they are ignored for other
|
||||
/// transports.
|
||||
struct Ib {
|
||||
static const int DefaultMaxCqSize = 1024;
|
||||
static const int DefaultMaxCqPollNum = 1;
|
||||
static const int DefaultMaxSendWr = 8192;
|
||||
static const int DefaultMaxWrPerSend = 64;
|
||||
|
||||
/// Maximum size of the completion queue.
|
||||
int maxCqSize;
|
||||
/// Maximum number of completion queue polls per operation.
|
||||
int maxCqPollNum;
|
||||
/// Maximum number of outstanding send work requests.
|
||||
int maxSendWr;
|
||||
/// Maximum number of work requests per send operation.
|
||||
int maxWrPerSend;
|
||||
|
||||
/// Constructor.
|
||||
/// @param maxCqSize Maximum completion queue size.
|
||||
/// @param maxCqPollNum Maximum completion queue poll count.
|
||||
/// @param maxSendWr Maximum outstanding send work requests.
|
||||
/// @param maxWrPerSend Maximum work requests per send operation.
|
||||
Ib(int maxCqSize = DefaultMaxCqSize, int maxCqPollNum = DefaultMaxCqPollNum, int maxSendWr = DefaultMaxSendWr,
|
||||
int maxWrPerSend = DefaultMaxWrPerSend)
|
||||
: maxCqSize(maxCqSize), maxCqPollNum(maxCqPollNum), maxSendWr(maxSendWr), maxWrPerSend(maxWrPerSend) {}
|
||||
};
|
||||
|
||||
/// Communication transport type (e.g., CudaIpc, IB0-IB7, Ethernet).
|
||||
Transport transport;
|
||||
/// Target device for the endpoint (GPU or CPU with optional device ID).
|
||||
Device device;
|
||||
int ibMaxCqSize;
|
||||
int ibMaxCqPollNum;
|
||||
int ibMaxSendWr;
|
||||
int ibMaxWrPerSend;
|
||||
/// Maximum number of write requests that can be queued (-1 for default).
|
||||
int maxWriteQueueSize;
|
||||
/// InfiniBand-specific options (used only for Transport::IBx).
|
||||
Ib ib;
|
||||
|
||||
/// Constructor that takes a transport and sets the other fields to their default values.
|
||||
///
|
||||
/// @param transport The transport to use.
|
||||
/// @param device The device to use.
|
||||
/// @param ibMaxCqSize The maximum completion queue size.
|
||||
/// @param ibMaxCqPollNum The maximum completion queue poll number.
|
||||
/// @param ibMaxSendWr The maximum send work requests.
|
||||
/// @param ibMaxWrPerSend The maximum work requests per send.
|
||||
/// @param maxWriteQueueSize The maximum write queue size.
|
||||
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
|
||||
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
|
||||
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
|
||||
int maxWriteQueueSize = -1)
|
||||
: transport(transport),
|
||||
device(device),
|
||||
ibMaxCqSize(ibMaxCqSize),
|
||||
ibMaxCqPollNum(ibMaxCqPollNum),
|
||||
ibMaxSendWr(ibMaxSendWr),
|
||||
ibMaxWrPerSend(ibMaxWrPerSend),
|
||||
maxWriteQueueSize(maxWriteQueueSize) {}
|
||||
/// Constructs endpoint configuration with specified transport, device, and optional settings.
|
||||
/// @param transport Communication transport to use.
|
||||
/// @param device Target device for the endpoint.
|
||||
/// @param maxWriteQueueSize Maximum write queue size (-1 for system default).
|
||||
/// @param ib IB-specific configuration.
|
||||
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU, int maxWriteQueueSize = -1,
|
||||
Ib ib = {})
|
||||
: transport(transport), device(device), maxWriteQueueSize(maxWriteQueueSize), ib(ib) {}
|
||||
};
|
||||
|
||||
class Context;
|
||||
@@ -423,6 +435,10 @@ class Endpoint {
|
||||
/// Constructor.
|
||||
Endpoint() = default;
|
||||
|
||||
/// Get the configuration used to create the endpoint.
|
||||
/// @return The configuration used to create the endpoint.
|
||||
const EndpointConfig& config() const;
|
||||
|
||||
/// Get the transport used.
|
||||
/// @return The transport used.
|
||||
Transport transport() const;
|
||||
@@ -685,9 +701,9 @@ class Semaphore {
|
||||
std::shared_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
/// Deprecated alias of `std::shared_future<T>`, kept so existing code that names `NonblockingFuture<T>` still
/// compiles. Use `std::shared_future` directly in new code.
template <typename T>
using NonblockingFuture = std::shared_future<T>;
|
||||
|
||||
/// A class that sets up all registered memories and connections between processes.
|
||||
///
|
||||
@@ -853,12 +869,20 @@ class Communicator {
|
||||
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
|
||||
/// back to back.
|
||||
///
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @param localEndpoint The local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
///
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
|
||||
|
||||
/// Connect to a remote rank. Wrapper of `connect(localEndpoint, remoteRank, tag)`.
|
||||
/// @param localConfig The configuration for the local endpoint.
|
||||
/// @param remoteRank The rank of the remote process.
|
||||
/// @param tag The tag to use for identifying the send and receive.
|
||||
/// @return A future of shared pointer to the connection.
|
||||
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
|
||||
int tag = 0);
|
||||
|
||||
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
|
||||
shared_future<std::shared_ptr<Connection>>
|
||||
|
||||
@@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
|
||||
.def_rw("id", &Device::id)
|
||||
.def("__str__", [](const Device& self) { return std::to_string(self); });
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
|
||||
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
|
||||
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
|
||||
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
|
||||
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
|
||||
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
|
||||
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
@@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
|
||||
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
|
||||
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
|
||||
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
|
||||
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("device", &EndpointConfig::device)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
|
||||
.def_rw("ib", &EndpointConfig::ib)
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
@@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
|
||||
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
|
||||
.def(
|
||||
"connect",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
|
||||
return self->connect(localConfig, remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def(
|
||||
|
||||
@@ -99,41 +99,44 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
/// Connect to a remote rank using an already-created local endpoint.
///
/// If `remoteRank` is this process's own bootstrap rank, a second endpoint is created from the local endpoint's
/// configuration, the two are connected immediately, and an already-fulfilled future is returned.
///
/// Otherwise the serialized local endpoint is sent to `remoteRank` through the bootstrap right away, and a deferred
/// future is returned. The deferred work runs when the future is first waited on: it waits for the previously
/// registered receive item for the same (remoteRank, tag), receives and deserializes the remote endpoint, and
/// creates the connection.
///
/// @param localEndpoint The local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
                                                                                      int remoteRank, int tag) {
  if (remoteRank == bootstrap()->getRank()) {
    // Connection to self
    auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
    auto connection = context()->connect(localEndpoint, remoteEndpoint);
    std::promise<std::shared_ptr<Connection>> promise;
    promise.set_value(connection);
    pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
    return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
  }

  bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);

  // `localEndpoint` is captured by copy: the lambda runs later (deferred), after the reference parameter
  // may no longer refer to a live object.
  auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint,
                                                   lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable {
    if (lastRecvItem) {
      // Recursive call to the previous receive items
      lastRecvItem->wait();
    }
    std::vector<char> data;
    bootstrap()->recv(data, remoteRank, tag);
    auto remoteEndpoint = Endpoint::deserialize(data);
    auto connection = context()->connect(localEndpoint, remoteEndpoint);
    pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
    return connection;
  });
  auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
  // Record this receive as the latest for (remoteRank, tag) so that later receives wait on it first.
  pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
  return shared_future;
}
|
||||
|
||||
/// Convenience overload: asks the context to create a local endpoint from `localConfig`, then forwards to the
/// `Endpoint`-based `connect()` with the same `remoteRank` and `tag`.
///
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
                                                                                      int remoteRank, int tag) {
  return connect(context()->createEndpoint(localConfig), remoteRank, tag);
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
|
||||
EndpointConfig localConfig) {
|
||||
return connect(localConfig, remoteRank, tag);
|
||||
|
||||
@@ -167,9 +167,6 @@ IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& loc
|
||||
transport_(localEndpoint.transport()),
|
||||
remoteTransport_(remoteEndpoint.transport()),
|
||||
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
|
||||
if (maxWriteQueueSize_ == -1) {
|
||||
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
|
||||
}
|
||||
qp_ = getImpl(localEndpoint).ibQp_;
|
||||
qp_.lock()->rtr(getImpl(remoteEndpoint).ibQpInfo_);
|
||||
qp_.lock()->rts();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "context.hpp"
|
||||
|
||||
#include <mscclpp/env.hpp>
|
||||
#include <sstream>
|
||||
|
||||
#include "api.h"
|
||||
#include "connection.hpp"
|
||||
@@ -76,21 +77,21 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
|
||||
if (remoteEndpoint.device().type == DeviceType::GPU && remoteEndpoint.device().id < 0) {
|
||||
throw Error("No GPU device ID provided for remote endpoint", ErrorCode::InvalidUsage);
|
||||
}
|
||||
auto localTransport = localEndpoint.transport();
|
||||
auto remoteTransport = remoteEndpoint.transport();
|
||||
if (localTransport != remoteTransport &&
|
||||
!(AllIBTransports.has(localTransport) && AllIBTransports.has(remoteTransport))) {
|
||||
std::stringstream ss;
|
||||
ss << "Transport mismatch between local (" << std::to_string(localTransport) << ") and remote ("
|
||||
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
|
||||
throw Error(ss.str(), ErrorCode::InvalidUsage);
|
||||
}
|
||||
std::shared_ptr<Connection> conn;
|
||||
if (localEndpoint.transport() == Transport::CudaIpc) {
|
||||
if (remoteEndpoint.transport() != Transport::CudaIpc) {
|
||||
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
if (localTransport == Transport::CudaIpc) {
|
||||
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else if (AllIBTransports.has(localEndpoint.transport())) {
|
||||
if (!AllIBTransports.has(remoteEndpoint.transport())) {
|
||||
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
} else if (AllIBTransports.has(localTransport)) {
|
||||
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else if (localEndpoint.transport() == Transport::Ethernet) {
|
||||
if (remoteEndpoint.transport() != Transport::Ethernet) {
|
||||
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
|
||||
}
|
||||
} else if (localTransport == Transport::Ethernet) {
|
||||
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
|
||||
} else {
|
||||
throw Error("Unsupported transport", ErrorCode::InternalError);
|
||||
|
||||
@@ -13,21 +13,21 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
: transport_(config.transport),
|
||||
device_(config.device),
|
||||
hostHash_(getHostHash()),
|
||||
pidHash_(getPidHash()),
|
||||
maxWriteQueueSize_(config.maxWriteQueueSize) {
|
||||
if (device_.type == DeviceType::GPU && device_.id < 0) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&(device_.id)));
|
||||
Endpoint::Impl::Impl(const EndpointConfig& config, Context::Impl& contextImpl)
|
||||
: config_(config), hostHash_(getHostHash()), pidHash_(getPidHash()) {
|
||||
if (config_.device.type == DeviceType::GPU && config_.device.id < 0) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&(config_.device.id)));
|
||||
}
|
||||
if (AllIBTransports.has(transport_)) {
|
||||
if (AllIBTransports.has(config_.transport)) {
|
||||
ibLocal_ = true;
|
||||
ibQp_ = contextImpl.getIbContext(transport_)
|
||||
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
|
||||
if (config_.maxWriteQueueSize <= 0) {
|
||||
config_.maxWriteQueueSize = config_.ib.maxCqSize;
|
||||
}
|
||||
ibQp_ =
|
||||
contextImpl.getIbContext(config_.transport)
|
||||
->createQp(config_.ib.maxCqSize, config_.ib.maxCqPollNum, config_.ib.maxSendWr, 0, config_.ib.maxWrPerSend);
|
||||
ibQpInfo_ = ibQp_->getInfo();
|
||||
} else if (transport_ == Transport::Ethernet) {
|
||||
} else if (config_.transport == Transport::Ethernet) {
|
||||
// Configuring Ethernet Interfaces
|
||||
abortFlag_ = 0;
|
||||
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1);
|
||||
@@ -42,41 +42,42 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
|
||||
|
||||
/// Rebuilds an endpoint implementation from a byte buffer.
///
/// Fields are read in the same order `Endpoint::serialize()` writes them: the configuration, the host hash, the
/// process hash, and then a transport-specific trailer (`ibQpInfo_` for IB transports, `socketAddress_` for
/// Ethernet).
///
/// @param serialization The serialized endpoint bytes.
/// @throws Error with ErrorCode::Aborted if bytes remain after all expected fields have been read.
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
  auto it = serialization.begin();
  it = detail::deserialize(it, config_);
  it = detail::deserialize(it, hostHash_);
  it = detail::deserialize(it, pidHash_);
  if (AllIBTransports.has(config_.transport)) {
    // NOTE(review): presumably marks the queue pair info as describing a remote endpoint; confirm.
    ibLocal_ = false;
    it = detail::deserialize(it, ibQpInfo_);
  } else if (config_.transport == Transport::Ethernet) {
    it = detail::deserialize(it, socketAddress_);
  }
  if (it != serialization.end()) {
    throw Error("Endpoint deserialization failed", ErrorCode::Aborted);
  }
}
|
||||
|
||||
/// Constructs an endpoint around an existing implementation object; the given shared pointer is stored as `pimpl_`.
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
|
||||
|
||||
/// Returns the configuration stored in the endpoint implementation.
MSCCLPP_API_CPP const EndpointConfig& Endpoint::config() const { return pimpl_->config_; }

/// Returns the transport recorded in the endpoint's configuration.
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->config_.transport; }

/// Returns the device recorded in the endpoint's configuration.
MSCCLPP_API_CPP const Device& Endpoint::device() const { return pimpl_->config_.device; }
|
||||
|
||||
/// Returns the host hash stored in the endpoint implementation (`hostHash_`).
MSCCLPP_API_CPP uint64_t Endpoint::hostHash() const { return pimpl_->hostHash_; }
|
||||
|
||||
/// Returns the process hash stored in the endpoint implementation (`pidHash_`).
MSCCLPP_API_CPP uint64_t Endpoint::pidHash() const { return pimpl_->pidHash_; }
|
||||
|
||||
/// Returns the maximum write queue size recorded in the endpoint's configuration.
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() const { return pimpl_->config_.maxWriteQueueSize; }
|
||||
|
||||
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() const {
|
||||
std::vector<char> data;
|
||||
detail::serialize(data, pimpl_->transport_);
|
||||
detail::serialize(data, pimpl_->device_);
|
||||
detail::serialize(data, pimpl_->config_);
|
||||
detail::serialize(data, pimpl_->hostHash_);
|
||||
detail::serialize(data, pimpl_->pidHash_);
|
||||
if (AllIBTransports.has(pimpl_->transport_)) {
|
||||
if (AllIBTransports.has(pimpl_->config_.transport)) {
|
||||
detail::serialize(data, pimpl_->ibQpInfo_);
|
||||
}
|
||||
if ((pimpl_->transport_) == Transport::Ethernet) {
|
||||
} else if (pimpl_->config_.transport == Transport::Ethernet) {
|
||||
detail::serialize(data, pimpl_->socketAddress_);
|
||||
}
|
||||
return data;
|
||||
|
||||
@@ -36,7 +36,7 @@ class RecvItem : public BaseRecvItem {
|
||||
|
||||
class LocalRecvMemory {
|
||||
public:
|
||||
LocalRecvMemory() : future_(std::move(promise_.get_future())) {}
|
||||
LocalRecvMemory() : future_(promise_.get_future()) {}
|
||||
|
||||
void set(RegisteredMemory memory) { promise_.set_value(std::move(memory)); }
|
||||
|
||||
|
||||
@@ -15,14 +15,12 @@
|
||||
namespace mscclpp {
|
||||
|
||||
struct Endpoint::Impl {
|
||||
Impl(EndpointConfig config, Context::Impl& contextImpl);
|
||||
Impl(const EndpointConfig& config, Context::Impl& contextImpl);
|
||||
Impl(const std::vector<char>& serialization);
|
||||
|
||||
Transport transport_;
|
||||
Device device_;
|
||||
EndpointConfig config_;
|
||||
uint64_t hostHash_;
|
||||
uint64_t pidHash_;
|
||||
int maxWriteQueueSize_;
|
||||
|
||||
// The following are only used for IB and are undefined for other transports.
|
||||
bool ibLocal_;
|
||||
|
||||
Reference in New Issue
Block a user