mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Integrate MSCCL++ DSL to torch workload (#620)
Provides two integration ways for MSCCL++ DSL. 1. Integrate with customized communication group 2. Integrate with NCCL API Introduce new Python APIs to make it work: ```python mscclpp.compile # compile dsl to json based execution plan mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan to executionPlanRegistery mscclpp.ExecutionPlanRegistry.set_selector(selector) # set the selector, the selector will return the best execution plan based on collection, message size, world size.... ``` Fix #556 --------- Co-authored-by: Caio Rocha <caiorocha@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
48
python/csrc/memory_channel_py.cpp
Normal file
48
python/csrc/memory_channel_py.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
})
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
|
||||
.def_rw("packetBuffer_", &MemoryChannel::DeviceHandle::packetBuffer_)
|
||||
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user