register buffers

This commit is contained in:
Crutcher Dunnavant
2023-04-07 19:11:50 -07:00
parent cc8c30f958
commit 34464b40bb
5 changed files with 137 additions and 42 deletions

View File

@@ -1,10 +1,9 @@
#include <cuda_runtime.h>
#include <mscclpp.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>
#include <cstring>
@@ -71,14 +70,14 @@ void checkResult(
}
}
// RETRY(C, ...): evaluate the expression C in a loop for as long as it yields
// mscclppInProgress, then pass the final result plus the remaining macro
// arguments to checkResult().
//
// Notes for callers:
//  - C is re-evaluated on every iteration, so any side effects in it repeat.
//  - The expansion is a bare `{ ... }` compound statement that declares a
//    local named `res`; C must not rely on an outer variable called `res`.
#define RETRY(C, ...) \
{ \
mscclppResult_t res; \
do { \
res = (C); \
} while (res == mscclppInProgress); \
checkResult(res, __VA_ARGS__); \
}
// Second definition of RETRY; token-for-token identical to the one directly
// above (only the layout of the original text differed).
#define RETRY(C, ...) \
{ \
mscclppResult_t res; \
do { \
res = (C); \
} while (res == mscclppInProgress); \
checkResult(res, __VA_ARGS__); \
}
// Maybe return the value, maybe throw an exception.
template <typename Val, typename... Args>
@@ -98,7 +97,11 @@ struct _Comm {
public:
_Comm(int rank, int world_size, mscclppComm_t handle)
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {}
: _rank(rank),
_world_size(world_size),
_handle(handle),
_is_open(true),
_proxies_running(false) {}
~_Comm() { close(); }
@@ -106,8 +109,8 @@ struct _Comm {
void close() {
if (_is_open) {
if (_proxies_running) {
mscclppProxyStop(_handle);
_proxies_running = false;
mscclppProxyStop(_handle);
_proxies_running = false;
}
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
_handle = NULL;
@@ -124,6 +127,21 @@ struct _Comm {
}
};
// Value-type wrapper around one mscclppRegisteredMemoryP2P entry.
//
// The wrapped struct carries an `IbMr` pointer. When that pointer is non-null,
// the pointee is copied into the `_ibmr` member and `IbMr` is re-aimed at that
// member, so the handle only refers to storage it holds itself.
//
// Because `_rmP2P.IbMr` can point into this very object, copying needs care:
// the compiler-generated copy would duplicate the pointer verbatim and leave
// the new object pointing at the *source* object's `_ibmr`. That dangles as
// soon as the source goes away (for example a temporary handed to
// std::vector::push_back, or the old buffer after a vector reallocation).
// The copy operations below therefore re-aim the pointer at the destination's
// own `_ibmr`. No move operations are declared, so moves fall back to these
// copies.
struct _P2PHandle {
  struct mscclppRegisteredMemoryP2P _rmP2P;
  struct mscclppIbMr _ibmr;

  // Empty handle: both members are zero-initialized.
  _P2PHandle() : _rmP2P({0}), _ibmr({0}) {}

  // Snapshot `p2p`; if it references an IB MR, keep a private copy of it.
  _P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
    _rmP2P = p2p;
    _adoptIbMr();
  }

  // Copy both members, then make `IbMr` refer to this object's `_ibmr`.
  _P2PHandle(const _P2PHandle& other)
      : _rmP2P(other._rmP2P), _ibmr(other._ibmr) {
    _adoptIbMr();
  }

  // Same fix-up as the copy constructor; self-assignment is a no-op.
  _P2PHandle& operator=(const _P2PHandle& other) {
    if (this != &other) {
      _rmP2P = other._rmP2P;
      _ibmr = other._ibmr;
      _adoptIbMr();
    }
    return *this;
  }

 private:
  // If `_rmP2P.IbMr` is set, copy its pointee into `_ibmr` and point at it.
  void _adoptIbMr() {
    if (_rmP2P.IbMr != nullptr) {
      _ibmr = *_rmP2P.IbMr;
      _rmP2P.IbMr = &_ibmr;
    }
  }
};
nb::callable _log_callback;
void _LogHandler(const char* msg) {
@@ -138,6 +156,8 @@ static const std::string DOC_MscclppUniqueId =
// Docstrings attached as the static `__doc__` attribute of the bound
// `_Comm` and `_P2PHandle` Python classes.
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle";
NB_MODULE(_py_mscclpp, m) {
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
@@ -188,6 +208,9 @@ NB_MODULE(_py_mscclpp, m) {
return nb::bytes(id.internal, sizeof(id.internal));
});
// Register _P2PHandle with the module. Only the static `__doc__` attribute is
// bound here: no constructor, fields or methods are exposed in this statement.
nb::class_<_P2PHandle>(m, "_P2PHandle")
.def_ro_static("__doc__", &DOC__P2PHandle);
nb::class_<_Comm>(m, "_Comm")
.def_ro_static("__doc__", &DOC__Comm)
.def_static(
@@ -236,6 +259,29 @@ NB_MODULE(_py_mscclpp, m) {
"Is this comm object closed?")
.def_ro("rank", &_Comm::_rank)
.def_ro("world_size", &_Comm::_world_size)
.def(
    "register_buffer",
    // Hands `local_buff` (reinterpreted as a void* address) and `buff_size`
    // to mscclppRegisterBuffer on this comm's handle, then returns one
    // _P2PHandle for every entry of the resulting `p2p` collection.
    [](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> {
      comm.check_open();

      void* const buffer = reinterpret_cast<void*>(local_buff);
      mscclppRegisteredMemory registered;
      const auto status =
          mscclppRegisterBuffer(comm._handle, buffer, buff_size, &registered);
      checkResult(status, "Registering buffer failed");

      std::vector<_P2PHandle> result;
      for (const auto& entry : registered.p2p) {
        result.push_back(_P2PHandle(entry));
      }
      return result;
    },
    "local_buf"_a,
    "buff_size"_a,
    nb::call_guard<nb::gil_scoped_release>(),
    "Register a buffer for P2P transfers.")
.def(
"connect",
[](_Comm& comm,
@@ -244,9 +290,7 @@ NB_MODULE(_py_mscclpp, m) {
uint64_t local_buff,
uint64_t buff_size,
mscclppTransport_t transport_type) -> void {
if (comm._proxies_running) {
throw std::invalid_argument("Proxy Threads Already Running");
}
comm.check_open();
RETRY(
mscclppConnect(
comm._handle,
@@ -270,8 +314,9 @@ NB_MODULE(_py_mscclpp, m) {
"connection_setup",
[](_Comm& comm) -> void {
comm.check_open();
RETRY(mscclppConnectionSetup(comm._handle),
"Failed to setup MSCCLPP connection");
RETRY(
mscclppConnectionSetup(comm._handle),
"Failed to setup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")
@@ -304,15 +349,6 @@ NB_MODULE(_py_mscclpp, m) {
"Start the MSCCLPP proxy.")
// _Comm::close is bound twice: as the explicit `close()` method and as
// `__del__`. Both bindings use a gil_scoped_release call guard.
.def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def(
    "connection_setup",
    // Verifies the comm is open, calls mscclppConnectionSetup once on its
    // handle, and passes the returned status to checkResult.
    [](_Comm& comm) -> void {
      comm.check_open();
      const auto status = mscclppConnectionSetup(comm._handle);
      checkResult(status, "Connection Setup Failed");
    },
    nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather_int",
[](_Comm& comm, int val) -> std::vector<int> {