mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
register buffers
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <mscclpp.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -71,14 +70,14 @@ void checkResult(
|
||||
}
|
||||
}
|
||||
|
||||
#define RETRY(C, ...) \
|
||||
{ \
|
||||
mscclppResult_t res; \
|
||||
do { \
|
||||
res = (C); \
|
||||
} while (res == mscclppInProgress); \
|
||||
checkResult(res, __VA_ARGS__); \
|
||||
}
|
||||
#define RETRY(C, ...) \
|
||||
{ \
|
||||
mscclppResult_t res; \
|
||||
do { \
|
||||
res = (C); \
|
||||
} while (res == mscclppInProgress); \
|
||||
checkResult(res, __VA_ARGS__); \
|
||||
}
|
||||
|
||||
// Maybe return the value, maybe throw an exception.
|
||||
template <typename Val, typename... Args>
|
||||
@@ -98,7 +97,11 @@ struct _Comm {
|
||||
|
||||
public:
|
||||
_Comm(int rank, int world_size, mscclppComm_t handle)
|
||||
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {}
|
||||
: _rank(rank),
|
||||
_world_size(world_size),
|
||||
_handle(handle),
|
||||
_is_open(true),
|
||||
_proxies_running(false) {}
|
||||
|
||||
~_Comm() { close(); }
|
||||
|
||||
@@ -106,8 +109,8 @@ struct _Comm {
|
||||
void close() {
|
||||
if (_is_open) {
|
||||
if (_proxies_running) {
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
}
|
||||
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
|
||||
_handle = NULL;
|
||||
@@ -124,6 +127,21 @@ struct _Comm {
|
||||
}
|
||||
};
|
||||
|
||||
struct _P2PHandle {
|
||||
struct mscclppRegisteredMemoryP2P _rmP2P;
|
||||
struct mscclppIbMr _ibmr;
|
||||
|
||||
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
|
||||
|
||||
_P2PHandle(const mscclppRegisteredMemoryP2P &p2p): _ibmr({0}) {
|
||||
_rmP2P = p2p;
|
||||
if (_rmP2P.IbMr != nullptr) {
|
||||
_ibmr = *_rmP2P.IbMr;
|
||||
_rmP2P.IbMr = &_ibmr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
nb::callable _log_callback;
|
||||
|
||||
void _LogHandler(const char* msg) {
|
||||
@@ -138,6 +156,8 @@ static const std::string DOC_MscclppUniqueId =
|
||||
|
||||
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
|
||||
|
||||
static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle";
|
||||
|
||||
NB_MODULE(_py_mscclpp, m) {
|
||||
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
|
||||
|
||||
@@ -188,6 +208,9 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
return nb::bytes(id.internal, sizeof(id.internal));
|
||||
});
|
||||
|
||||
nb::class_<_P2PHandle>(m, "_P2PHandle")
|
||||
.def_ro_static("__doc__", &DOC__P2PHandle);
|
||||
|
||||
nb::class_<_Comm>(m, "_Comm")
|
||||
.def_ro_static("__doc__", &DOC__Comm)
|
||||
.def_static(
|
||||
@@ -236,6 +259,29 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Is this comm object closed?")
|
||||
.def_ro("rank", &_Comm::_rank)
|
||||
.def_ro("world_size", &_Comm::_world_size)
|
||||
.def(
|
||||
"register_buffer",
|
||||
[](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> {
|
||||
comm.check_open();
|
||||
mscclppRegisteredMemory regMem;
|
||||
checkResult(
|
||||
mscclppRegisterBuffer(
|
||||
comm._handle,
|
||||
reinterpret_cast<void*>(local_buff),
|
||||
buff_size,
|
||||
®Mem),
|
||||
"Registering buffer failed");
|
||||
|
||||
std::vector<_P2PHandle> handles;
|
||||
for (const auto &p2p : regMem.p2p) {
|
||||
handles.push_back(_P2PHandle(p2p));
|
||||
}
|
||||
return handles;
|
||||
},
|
||||
"local_buf"_a,
|
||||
"buff_size"_a,
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Register a buffer for P2P transfers.")
|
||||
.def(
|
||||
"connect",
|
||||
[](_Comm& comm,
|
||||
@@ -244,9 +290,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size,
|
||||
mscclppTransport_t transport_type) -> void {
|
||||
if (comm._proxies_running) {
|
||||
throw std::invalid_argument("Proxy Threads Already Running");
|
||||
}
|
||||
comm.check_open();
|
||||
RETRY(
|
||||
mscclppConnect(
|
||||
comm._handle,
|
||||
@@ -270,8 +314,9 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"connection_setup",
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
RETRY(mscclppConnectionSetup(comm._handle),
|
||||
"Failed to setup MSCCLPP connection");
|
||||
RETRY(
|
||||
mscclppConnectionSetup(comm._handle),
|
||||
"Failed to setup MSCCLPP connection");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Run connection setup for MSCCLPP.")
|
||||
@@ -304,15 +349,6 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Start the MSCCLPP proxy.")
|
||||
.def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
checkResult(
|
||||
mscclppConnectionSetup(comm._handle),
|
||||
"Connection Setup Failed");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"bootstrap_all_gather_int",
|
||||
[](_Comm& comm, int val) -> std::vector<int> {
|
||||
|
||||
@@ -18,6 +18,7 @@ __all__ = (
|
||||
)
|
||||
|
||||
_Comm = _py_mscclpp._Comm
|
||||
_P2PHandle = _py_mscclpp._P2PHandle
|
||||
TransportType = _py_mscclpp.TransportType
|
||||
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
@@ -46,6 +47,7 @@ MSCCLPP_LOG_LEVELS: set[str] = {
|
||||
"TRACE",
|
||||
}
|
||||
|
||||
|
||||
def _setup_logging(level: str = "INFO"):
|
||||
"""Setup log hooks for the C library."""
|
||||
level = level.upper()
|
||||
@@ -69,11 +71,11 @@ class Comm:
|
||||
|
||||
@staticmethod
|
||||
def init_rank_from_address(
|
||||
address: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
address: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
"""Initialize a Comm from an address.
|
||||
|
||||
@@ -154,12 +156,12 @@ class Comm:
|
||||
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
|
||||
|
||||
def connect(
|
||||
self,
|
||||
remote_rank: int,
|
||||
tag: int,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
transport: int,
|
||||
self,
|
||||
remote_rank: int,
|
||||
tag: int,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
transport: int,
|
||||
) -> None:
|
||||
self._comm.connect(
|
||||
remote_rank,
|
||||
@@ -177,3 +179,26 @@ class Comm:
|
||||
|
||||
def stop_proxies(self) -> None:
|
||||
self._comm.stop_proxies()
|
||||
|
||||
def register_buffer(
|
||||
self,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
) -> list[_P2PHandle]:
|
||||
return [
|
||||
P2PHandle(self, h) for h in self._comm.register_buffer(
|
||||
data_ptr,
|
||||
data_size,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class P2PHandle:
|
||||
_comm: Comm
|
||||
_handle: _P2PHandle
|
||||
|
||||
def __init__(self,
|
||||
comm: Comm,
|
||||
handle: _P2PHandle):
|
||||
self._comm = comm
|
||||
self._handle = handle
|
||||
|
||||
@@ -79,6 +79,8 @@ class CommsTest(unittest.TestCase):
|
||||
if errors:
|
||||
parts = []
|
||||
for rank, content in errors:
|
||||
parts.append(f"[rank {rank}]: " + content.decode('utf-8', errors='ignore'))
|
||||
parts.append(
|
||||
f"[rank {rank}]: " + content.decode("utf-8", errors="ignore")
|
||||
)
|
||||
|
||||
raise AssertionError("\n\n".join(parts))
|
||||
|
||||
@@ -73,7 +73,7 @@ def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.
|
||||
comm.connection_setup()
|
||||
|
||||
|
||||
def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
rank = options.rank
|
||||
|
||||
buf = torch.zeros([options.world_size], dtype=torch.int64)
|
||||
@@ -95,6 +95,12 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
mscclpp.TransportType.P2P,
|
||||
)
|
||||
|
||||
handles = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel())
|
||||
hamcrest.assert_that(
|
||||
handles,
|
||||
hamcrest.has_length(options.world_size - 1),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
comm.connection_setup()
|
||||
@@ -103,7 +109,6 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
comm.stop_proxies()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rank", type=int, required=True)
|
||||
@@ -131,7 +136,7 @@ def main():
|
||||
_test_bootstrap_allgather_bytes(options, comm)
|
||||
_test_bootstrap_allgather_json(options, comm)
|
||||
_test_bootstrap_allgather_pickle(options, comm)
|
||||
_test_p2p_connect(options, comm)
|
||||
_test_rm(options, comm)
|
||||
finally:
|
||||
comm.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user