// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include #include #include #include namespace nb = nanobind; using namespace mscclpp; void register_memory_channel(nb::module_& m) { nb::class_(m, "CppBaseMemoryChannel") .def(nb::init<>()) .def(nb::init>(), nb::arg("semaphore")) .def(nb::init(), nb::arg("semaphore")) .def("device_handle", &BaseMemoryChannel::deviceHandle); nb::class_(m, "CppBaseMemoryChannelDeviceHandle") .def(nb::init<>()) .def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_) .def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes { return nb::bytes(reinterpret_cast(&self), sizeof(self)); }); nb::class_(m, "CppMemoryChannel") .def(nb::init<>()) .def( "__init__", [](MemoryChannel* memoryChannel, std::shared_ptr semaphore, RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) { new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast(packet_buffer)); }, nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0) .def( "__init__", [](MemoryChannel* memoryChannel, const Semaphore& semaphore, RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer = 0) { new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast(packet_buffer)); }, nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0) .def("device_handle", &MemoryChannel::deviceHandle); nb::class_(m, "CppMemoryChannelDeviceHandle") .def(nb::init<>()) .def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_) .def_rw("dst_", &MemoryChannel::DeviceHandle::dst_) .def_rw("src_", &MemoryChannel::DeviceHandle::src_) .def_rw("packetBuffer_", &MemoryChannel::DeviceHandle::packetBuffer_) .def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes { return nb::bytes(reinterpret_cast(&self), sizeof(self)); }); };