mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Fix Python bindings and tests (#690)
Minimal fix to get the Python bindings and tests working again. A more careful follow-up is needed to prevent nanobind from silently falling back when it fails to (properly) construct a C++ STL object that holds mscclpp instances.
This commit is contained in:
@@ -216,6 +216,7 @@ void register_core(nb::module_& m) {
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<Connection>(m, "Connection");
|
||||
def_shared_future<Semaphore>(m, "Semaphore");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
@@ -242,7 +243,7 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
|
||||
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("local_flag"), nb::arg("remote_rank"),
|
||||
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("connection"), nb::arg("remote_rank"),
|
||||
nb::arg("tag") = 0)
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
|
||||
@@ -26,14 +26,20 @@ void register_memory_channel(nb::module_& m) {
|
||||
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
})
|
||||
.def(
|
||||
"__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
},
|
||||
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
|
||||
.def(
|
||||
"__init__",
|
||||
[](MemoryChannel* memoryChannel, const Semaphore& semaphore, RegisteredMemory dst, RegisteredMemory src,
|
||||
uintptr_t packet_buffer = 0) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
},
|
||||
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
|
||||
@@ -47,6 +47,7 @@ from ._mscclpp import (
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Semaphore,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
numa,
|
||||
@@ -79,6 +80,7 @@ __all__ = [
|
||||
"connect_nvls_collective",
|
||||
"EndpointConfig",
|
||||
"Fifo",
|
||||
"Semaphore",
|
||||
"Host2DeviceSemaphore",
|
||||
"Host2HostSemaphore",
|
||||
"numa",
|
||||
|
||||
@@ -10,6 +10,7 @@ from ._mscclpp import (
|
||||
Connection,
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Semaphore,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
@@ -133,18 +134,14 @@ class CommGroup:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
def make_semaphore(
|
||||
self,
|
||||
connections: dict[int, Connection],
|
||||
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
|
||||
) -> dict[int, Host2HostSemaphore]:
|
||||
semaphores = {}
|
||||
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
|
||||
future_semaphores = {}
|
||||
for rank in connections:
|
||||
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
|
||||
return semaphores
|
||||
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
|
||||
return {rank: future.get() for rank, future in future_semaphores.items()}
|
||||
|
||||
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
channels = {}
|
||||
for rank in connections:
|
||||
@@ -159,7 +156,7 @@ class CommGroup:
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
@@ -177,7 +174,7 @@ class CommGroup:
|
||||
def make_port_channels(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
semaphore_ids = {}
|
||||
@@ -210,7 +207,7 @@ class CommGroup:
|
||||
)
|
||||
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
|
||||
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
memory_ids = {}
|
||||
semaphore_ids = {}
|
||||
@@ -229,7 +226,7 @@ class CommGroup:
|
||||
def register_semaphore_with_proxy(
|
||||
self, proxy_service: ProxyService, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
semaphore_ids = {}
|
||||
for rank in semaphores:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
|
||||
@@ -453,7 +453,10 @@ class MscclppAllReduce6:
|
||||
)
|
||||
|
||||
# create a memory_channel for each remote neighbor
|
||||
self.semaphores = group.make_semaphore(self.nvlink_connections, MemoryDevice2DeviceSemaphore)
|
||||
self.semaphores = {
|
||||
rank: MemoryDevice2DeviceSemaphore(sema)
|
||||
for rank, sema in group.make_semaphores(self.nvlink_connections).items()
|
||||
}
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.kernel = KernelBuilder(
|
||||
file="allreduce.cu",
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <vector>
|
||||
@@ -19,37 +18,39 @@ namespace nb = nanobind;
|
||||
|
||||
class MyProxyService {
|
||||
private:
|
||||
int deviceNumaNode_;
|
||||
int my_rank_, nranks_, dataSize_;
|
||||
std::vector<mscclpp::Connection> connections_;
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<mscclpp::RegisteredMemory> allRegMem_;
|
||||
std::vector<mscclpp::Host2DeviceSemaphore> semaphores_;
|
||||
mscclpp::Proxy proxy_;
|
||||
|
||||
public:
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, nb::list allRegMemList, nb::list semaphoreList)
|
||||
: my_rank_(my_rank),
|
||||
nranks_(nranks),
|
||||
dataSize_(dataSize),
|
||||
connections_(conns),
|
||||
allRegMem_(allRegMem),
|
||||
semaphores_(semaphores),
|
||||
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
|
||||
allRegMem_.reserve(allRegMemList.size());
|
||||
for (size_t i = 0; i < allRegMemList.size(); ++i) {
|
||||
auto& regMem = nb::cast<const mscclpp::RegisteredMemory&>(allRegMemList[i]);
|
||||
allRegMem_.push_back(regMem);
|
||||
}
|
||||
semaphores_.reserve(semaphoreList.size());
|
||||
for (size_t i = 0; i < semaphoreList.size(); ++i) {
|
||||
auto& sema = nb::cast<const mscclpp::Semaphore&>(semaphoreList[i]);
|
||||
semaphores_.emplace_back(sema);
|
||||
}
|
||||
}
|
||||
|
||||
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
|
||||
int dataSizePerRank = dataSize_ / nranks_;
|
||||
for (int r = 1; r < nranks_; ++r) {
|
||||
int nghr = (my_rank_ + r) % nranks_;
|
||||
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
semaphores_[nghr]->signal();
|
||||
connections_[nghr].flush();
|
||||
auto& sema = semaphores_[nghr];
|
||||
auto& conn = sema.connection();
|
||||
conn.write(allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
sema.signal();
|
||||
conn.flush();
|
||||
}
|
||||
return mscclpp::ProxyHandlerResult::Continue;
|
||||
}
|
||||
@@ -61,16 +62,11 @@ class MyProxyService {
|
||||
mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo()->deviceHandle(); }
|
||||
};
|
||||
|
||||
void init_mscclpp_proxy_test_module(nb::module_ &m) {
|
||||
NB_MODULE(_ext, m) {
|
||||
nb::class_<MyProxyService>(m, "MyProxyService")
|
||||
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
|
||||
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
|
||||
nb::arg("h2d_sem_vec"))
|
||||
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
|
||||
nb::arg("reg_mem_list"), nb::arg("sem_list"))
|
||||
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
|
||||
.def("start", &MyProxyService::start)
|
||||
.def("stop", &MyProxyService::stop);
|
||||
}
|
||||
|
||||
NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }
|
||||
|
||||
@@ -290,7 +290,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
semaphores = group.make_semaphores(connections)
|
||||
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
|
||||
for rank in connections:
|
||||
semaphores[rank].signal()
|
||||
|
||||
@@ -309,7 +310,8 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
semaphores = group.make_semaphores(connections)
|
||||
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
|
||||
|
||||
def target_wait(sems, conns):
|
||||
for rank in conns:
|
||||
@@ -457,7 +459,8 @@ def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
|
||||
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = group.make_semaphores(connections)
|
||||
semaphores = {rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
|
||||
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
|
||||
kernel()
|
||||
|
||||
@@ -473,7 +476,8 @@ def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
|
||||
def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
semaphores = group.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = group.make_semaphores(connections)
|
||||
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
|
||||
group.barrier()
|
||||
kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
|
||||
kernel()
|
||||
@@ -545,29 +549,29 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
|
||||
group.barrier()
|
||||
all_reg_memories = group.register_tensor_with_connections(memory, connections)
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = group.make_semaphores(connections)
|
||||
|
||||
list_conn = []
|
||||
list_sem = []
|
||||
list_reg_mem = []
|
||||
first_conn = next(iter(connections.values()))
|
||||
first_sem = next(iter(semaphores.values()))
|
||||
for rank in range(group.nranks):
|
||||
if rank in connections:
|
||||
list_conn.append(connections[rank])
|
||||
list_sem.append(semaphores[rank])
|
||||
else:
|
||||
list_conn.append(first_conn) # just for simplicity of indexing
|
||||
list_sem.append(first_sem)
|
||||
|
||||
list_reg_mem.append(all_reg_memories[rank])
|
||||
|
||||
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_conn, list_reg_mem, list_sem)
|
||||
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_reg_mem, list_sem)
|
||||
|
||||
fifo_device_handle = proxy.fifo_device_handle()
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
|
||||
"proxy",
|
||||
my_rank=group.my_rank,
|
||||
nranks=group.nranks,
|
||||
semaphore_or_channels={rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()},
|
||||
fifo=fifo_device_handle,
|
||||
)
|
||||
proxy.start()
|
||||
group.barrier()
|
||||
@@ -632,7 +636,8 @@ def test_nvls(mpi_group: MpiGroup):
|
||||
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
|
||||
|
||||
nvlinks_connections = create_connection(group, "NVLink")
|
||||
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = group.make_semaphores(nvlinks_connections)
|
||||
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"nvls",
|
||||
|
||||
Reference in New Issue
Block a user