mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
New semaphore constructors (#559)
Provides more intuitive interfaces for creating semaphores and channels. It also allows channels to be constructed directly with third-party bootstrappers, without overriding the MSCCL++ Bootstrap.
This commit is contained in:
@@ -14,6 +14,8 @@ from ._mscclpp import (
|
||||
CudaError,
|
||||
CuError,
|
||||
IbError,
|
||||
Device,
|
||||
DeviceType,
|
||||
Communicator,
|
||||
Connection,
|
||||
connect_nvls_collective,
|
||||
@@ -43,6 +45,8 @@ from ._mscclpp import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Device",
|
||||
"DeviceType",
|
||||
"Communicator",
|
||||
"Connection",
|
||||
"connect_nvls_collective",
|
||||
|
||||
@@ -101,7 +101,7 @@ class CommGroup:
|
||||
if endpoint.transport == Transport.Nvls:
|
||||
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect(rank, 0, endpoint)
|
||||
connections[rank] = self.communicator.connect(endpoint, rank)
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
@@ -124,8 +124,8 @@ class CommGroup:
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
future_memories = {}
|
||||
for rank in connections:
|
||||
self.communicator.send_memory(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank, 0)
|
||||
self.communicator.send_memory(local_reg_memory, rank)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank)
|
||||
for rank in connections:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
@@ -112,6 +112,19 @@ void register_core(nb::module_& m) {
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::enum_<DeviceType>(m, "DeviceType")
|
||||
.value("Unknown", DeviceType::Unknown)
|
||||
.value("CPU", DeviceType::CPU)
|
||||
.value("GPU", DeviceType::GPU);
|
||||
|
||||
nb::class_<Device>(m, "Device")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
|
||||
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
|
||||
.def_rw("type", &Device::type)
|
||||
.def_rw("id", &Device::id)
|
||||
.def("__str__", [](const Device& self) { return std::to_string(self); });
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", &RegisteredMemory::data)
|
||||
@@ -120,6 +133,13 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("device", &Endpoint::device)
|
||||
.def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
@@ -131,21 +151,26 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
.def("remote_transport", &Connection::remoteTransport)
|
||||
.def("context", &Connection::context)
|
||||
.def("local_device", &Connection::localDevice)
|
||||
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
|
||||
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
|
||||
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
|
||||
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("device", &EndpointConfig::device)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
.def_static("create", &Context::create)
|
||||
@@ -158,6 +183,19 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Semaphore>(m, "Semaphore")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("localStub"), nb::arg("remoteStub"))
|
||||
.def("connection", &Semaphore::connection)
|
||||
.def("local_memory", &Semaphore::localMemory)
|
||||
.def("remote_memory", &Semaphore::remoteMemory);
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
@@ -172,12 +210,28 @@ void register_core(nb::module_& m) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def(
|
||||
"connect",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def(
|
||||
"connect_on_setup",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("localFlag"), nb::arg("remoteRank"),
|
||||
nb::arg("tag") = 0)
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", [](Communicator*) {});
|
||||
|
||||
@@ -11,12 +11,10 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel> baseMemoryChannel(m, "BaseMemoryChannel");
|
||||
baseMemoryChannel
|
||||
.def("__init__",
|
||||
[](BaseMemoryChannel* baseMemoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore) {
|
||||
new (baseMemoryChannel) BaseMemoryChannel(semaphore);
|
||||
})
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
@@ -26,8 +24,8 @@ void register_memory_channel(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel> memoryChannel(m, "MemoryChannel");
|
||||
memoryChannel
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst,
|
||||
|
||||
@@ -20,13 +20,19 @@ void register_port_channel(nb::module_& m) {
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_semaphore", static_cast<SemaphoreId (ProxyService::*)(const Semaphore&)>(&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_semaphore",
|
||||
static_cast<SemaphoreId (ProxyService::*)(std::shared_ptr<Host2DeviceSemaphore>)>(
|
||||
&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
@@ -41,6 +47,7 @@ void register_port_channel(nb::module_& m) {
|
||||
});
|
||||
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
@@ -11,7 +11,7 @@ using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
@@ -19,13 +19,14 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_rw("inbound_token", &Host2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("expected_inbound_token", &Host2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
@@ -34,16 +35,17 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
|
||||
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
|
||||
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user