Files
mscclpp/python/csrc/switch_channel_py.cpp
Changho Hwang 105239fc6c Use GpuIpcMem for NVLS connections (#719)
* Now `NvlsConnection` internally reuses `GpuIpcMem` for multicast
memory handling.
* Removed unnecessary barriers from `connectNvlsCollective()` (CUDA API
handles this automatically).
* Updated `GpuIpcMem::map()` and `GpuIpcMem::mapMulticast()` to return a
shared pointer with a custom deleter that performs the unmapping, which
prevents misuse of raw pointers and reduces the amount of state stored in
the `GpuIpcMem` instance.
* For `RuntimeIpc`-type handles, `cudaIpcOpenMemHandle` is now called in
`GpuIpcMem::map()` instead of in the `GpuIpcMem` constructor, for
consistency with the other handle types.

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Binyang2014 <9415966+Binyang2014@users.noreply.github.com>
2026-01-15 13:16:04 +08:00

37 lines
1.4 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <mscclpp/core.hpp>
#include <mscclpp/switch_channel.hpp>
namespace nb = nanobind;
using namespace mscclpp;
/// Registers the switch-channel / NVLS Python bindings on module `m`.
///
/// Adds the classes `SwitchChannel`, `DeviceHandle` and `NvlsConnection`, plus the
/// module-level function `connect_nvls_collective`.
///
/// @param m nanobind module that receives the bindings.
void register_nvls(nb::module_& m) {
  using DeviceHandle = SwitchChannel::DeviceHandle;

  // SwitchChannel: the device pointer is handed to Python as an integer address.
  nb::class_<SwitchChannel> switchChannelCls(m, "SwitchChannel");
  switchChannelCls.def("get_device_ptr",
                       [](SwitchChannel* self) { return reinterpret_cast<uintptr_t>(self->getDevicePtr()); });
  switchChannelCls.def("device_handle", &SwitchChannel::deviceHandle);

  // DeviceHandle: default-constructible from Python, with read/write access to its fields.
  nb::class_<DeviceHandle> deviceHandleCls(m, "DeviceHandle");
  deviceHandleCls.def(nb::init<>());
  deviceHandleCls.def_rw("device_ptr", &DeviceHandle::devicePtr);
  deviceHandleCls.def_rw("mc_ptr", &DeviceHandle::mcPtr);
  deviceHandleCls.def_rw("size", &DeviceHandle::bufferSize);
  // `raw` yields the handle's object representation (sizeof(DeviceHandle) bytes) as Python bytes.
  deviceHandleCls.def_prop_ro("raw", [](const DeviceHandle& self) -> nb::bytes {
    const char* firstByte = reinterpret_cast<const char*>(&self);
    return nb::bytes(firstByte, sizeof(self));
  });

  // NvlsConnection: only memory binding is exposed.
  nb::class_<NvlsConnection> nvlsConnectionCls(m, "NvlsConnection");
  nvlsConnectionCls.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"),
                        nb::arg("size"));

  // Free function with keyword-named arguments.
  m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
        nb::arg("buffer_size"));
}