registeredmemory wip

This commit is contained in:
Crutcher Dunnavant
2023-04-11 13:24:19 -07:00
parent 19962a8002
commit b93cfa3ca4
3 changed files with 201 additions and 92 deletions

View File

@@ -128,16 +128,24 @@ struct _Comm {
};
struct _P2PHandle {
struct mscclppRegisteredMemoryP2P _rmP2P;
struct mscclppIbMr _ibmr;
struct mscclppRegisteredMemoryP2P _rm;
struct mscclppIbMr _local_ibmr;
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
_P2PHandle() : _rm({0}), _local_ibmr({0}) {}
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
_rmP2P = p2p;
if (_rmP2P.IbMr != nullptr) {
_ibmr = *_rmP2P.IbMr;
_rmP2P.IbMr = &_ibmr;
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _local_ibmr({0}) {
_rm = p2p;
if (_rm.IbMr != nullptr) {
_local_ibmr = *_rm.IbMr;
_rm.IbMr = &_local_ibmr;
}
}
// Report which transport this handle uses: a handle with an IB memory region
// attached (_rm.IbMr non-null) is IB; every other handle is P2P.
mscclppTransport_t transport() const {
  const bool has_ib_region = (_rm.IbMr != nullptr);
  return has_ib_region ? mscclppTransport_t::mscclppTransportIB
                       : mscclppTransport_t::mscclppTransportP2P;
}
};
@@ -209,7 +217,60 @@ NB_MODULE(_py_mscclpp, m) {
});
nb::class_<_P2PHandle>(m, "_P2PHandle")
.def_ro_static("__doc__", &DOC__P2PHandle);
.def_ro_static("__doc__", &DOC__P2PHandle)
.def(
"transport",
&_P2PHandle::transport,
"Get the transport type of the handle")
.def(
"data_ptr",
[](const _P2PHandle& self) -> uint64_t {
if (self.transport() == mscclppTransport_t::mscclppTransportP2P) {
return reinterpret_cast<uint64_t>(self._rm.remoteBuff);
}
throw std::invalid_argument(
"IB transport does not have a local data ptr");
},
"Get the local data pointer, only for P2P handles");
// Python binding for mscclppRegisteredMemory, exposed as "_RegisteredMemory".
nb::class_<mscclppRegisteredMemory>(m, "_RegisteredMemory")
.def(
"handles",
// Builds one _P2PHandle per entry of self.p2p and returns them as a list
// (the Python docstring below is worded in the singular).
[](const mscclppRegisteredMemory& self) -> std::vector<_P2PHandle> {
std::vector<_P2PHandle> handles;
for (const auto& p2p : self.p2p) {
handles.push_back(_P2PHandle(p2p));
}
return handles;
},
"Get the P2P handle for this memory")
.def(
"write_all",
// Forwards to mscclppRegisteredBufferWrite using the comm's handle.
// src_data is an integer address that is reinterpreted as void*.
// The `= 0` defaults on the lambda parameters repeat the Python-visible
// defaults declared through the `"..."_a = 0` annotations further down.
// NOTE(review): unlike the register_buffer/connect bindings in this module,
// no nb::call_guard<nb::gil_scoped_release>() is attached here — confirm
// whether that is intentional.
[](const mscclppRegisteredMemory& self,
const _Comm& comm,
uint64_t src_data,
size_t size,
uint32_t src_offset = 0,
uint32_t dst_offset = 0,
int64_t stream = 0) -> void {
// The call's status goes to checkResult together with the failure message;
// presumably checkResult raises on a non-success status — confirm.
checkResult(
mscclppRegisteredBufferWrite(
comm._handle,
// self is bound as const&, but the call is handed a mutable pointer.
// NOTE(review): confirm the callee does not modify the registration.
const_cast<mscclppRegisteredMemory*>(&self),
reinterpret_cast<void*>(src_data),
size,
src_offset,
dst_offset,
stream),
"Failed to write to registered memory");
},
"comm"_a,
"src_data"_a,
"size"_a,
"src_offset"_a = 0,
"dst_offset"_a = 0,
"stream"_a = 0,
"Write to all bound targets in the buffer");
nb::class_<_Comm>(m, "_Comm")
.def_ro_static("__doc__", &DOC__Comm)
@@ -248,58 +309,54 @@ NB_MODULE(_py_mscclpp, m) {
"id"_a,
"rank"_a,
"world_size"_a,
"Initialize comms given u UniqueID, rank, and world_size")
"Initialize comms given UniqueID, rank, and world_size")
.def(
"opened",
[](_Comm& comm) -> bool { return comm._is_open; },
[](_Comm& self) -> bool { return self._is_open; },
"Is this comm object opened?")
.def(
"closed",
[](_Comm& comm) -> bool { return !comm._is_open; },
[](_Comm& self) -> bool { return !self._is_open; },
"Is this comm object closed?")
.def_ro("rank", &_Comm::_rank)
.def_ro("world_size", &_Comm::_world_size)
.def(
"register_buffer",
[](_Comm& comm,
uint64_t local_buff,
uint64_t buff_size) -> std::vector<_P2PHandle> {
comm.check_open();
[](_Comm& self,
uint64_t data_ptr,
uint64_t size) -> mscclppRegisteredMemory {
self.check_open();
mscclppRegisteredMemory regMem;
checkResult(
mscclppRegisterBuffer(
comm._handle,
reinterpret_cast<void*>(local_buff),
buff_size,
self._handle,
reinterpret_cast<void*>(data_ptr),
size,
&regMem),
"Registering buffer failed");
std::vector<_P2PHandle> handles;
for (const auto& p2p : regMem.p2p) {
handles.push_back(_P2PHandle(p2p));
}
return handles;
return regMem;
;
},
"local_buf"_a,
"buff_size"_a,
"data_ptr"_a,
"size"_a,
nb::call_guard<nb::gil_scoped_release>(),
"Register a buffer for P2P transfers.")
.def(
"connect",
[](_Comm& comm,
[](_Comm& self,
int remote_rank,
int tag,
uint64_t local_buff,
uint64_t buff_size,
uint64_t data_ptr,
uint64_t size,
mscclppTransport_t transport_type) -> void {
comm.check_open();
self.check_open();
RETRY(
mscclppConnect(
comm._handle,
self._handle,
remote_rank,
tag,
reinterpret_cast<void*>(local_buff),
buff_size,
reinterpret_cast<void*>(data_ptr),
size,
transport_type,
NULL // ibDev
),
@@ -307,44 +364,44 @@ NB_MODULE(_py_mscclpp, m) {
},
"remote_rank"_a,
"tag"_a,
"local_buf"_a,
"buff_size"_a,
"data_ptr"_a,
"size"_a,
"transport_type"_a,
nb::call_guard<nb::gil_scoped_release>(),
"Attach a local buffer to a remote connection.")
.def(
"connection_setup",
[](_Comm& comm) -> void {
comm.check_open();
[](_Comm& self) -> void {
self.check_open();
RETRY(
mscclppConnectionSetup(comm._handle),
mscclppConnectionSetup(self._handle),
"Failed to setup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")
.def(
"launch_proxies",
[](_Comm& comm) -> void {
comm.check_open();
if (comm._proxies_running) {
[](_Comm& self) -> void {
self.check_open();
if (self._proxies_running) {
throw std::invalid_argument("Proxy Threads Already Running");
}
checkResult(
mscclppProxyLaunch(comm._handle),
mscclppProxyLaunch(self._handle),
"Failed to launch MSCCLPP proxy");
comm._proxies_running = true;
self._proxies_running = true;
},
nb::call_guard<nb::gil_scoped_release>(),
"Start the MSCCLPP proxy.")
.def(
"stop_proxies",
[](_Comm& comm) -> void {
comm.check_open();
if (comm._proxies_running) {
[](_Comm& self) -> void {
self.check_open();
if (self._proxies_running) {
checkResult(
mscclppProxyStop(comm._handle),
mscclppProxyStop(self._handle),
"Failed to stop MSCCLPP proxy");
comm._proxies_running = false;
self._proxies_running = false;
}
},
nb::call_guard<nb::gil_scoped_release>(),
@@ -353,10 +410,10 @@ NB_MODULE(_py_mscclpp, m) {
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather_int",
[](_Comm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
[](_Comm& self, int val) -> std::vector<int> {
std::vector<int> buf(self._world_size);
buf[self._rank] = val;
mscclppBootstrapAllGather(self._handle, buf.data(), sizeof(int));
return buf;
},
nb::call_guard<nb::gil_scoped_release>(),
@@ -364,13 +421,13 @@ NB_MODULE(_py_mscclpp, m) {
"all-gather ints over the bootstrap connection.")
.def(
"all_gather_bytes",
[](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
[](_Comm& self, nb::bytes& item) -> std::vector<nb::bytes> {
// First, all-gather the sizes of all bytes.
std::vector<size_t> sizes(comm._world_size);
sizes[comm._rank] = item.size();
std::vector<size_t> sizes(self._world_size);
sizes[self._rank] = item.size();
checkResult(
mscclppBootstrapAllGather(
comm._handle, sizes.data(), sizeof(size_t)),
self._handle, sizes.data(), sizeof(size_t)),
"bootstrapAllGather failed.");
// Next, find the largest message to send.
@@ -378,21 +435,21 @@ NB_MODULE(_py_mscclpp, m) {
// Allocate an all-gather buffer large enough for max * world_size.
std::shared_ptr<char[]> data_buf(
new char[max_size * comm._world_size]);
new char[max_size * self._world_size]);
// Copy the local item into the buffer.
std::memcpy(
&data_buf[comm._rank * max_size], item.c_str(), item.size());
&data_buf[self._rank * max_size], item.c_str(), item.size());
// all-gather the data buffer.
checkResult(
mscclppBootstrapAllGather(
comm._handle, data_buf.get(), max_size),
self._handle, data_buf.get(), max_size),
"bootstrapAllGather failed.");
// Build a response vector.
std::vector<nb::bytes> ret;
for (int i = 0; i < comm._world_size; ++i) {
for (int i = 0; i < self._world_size; ++i) {
// Copy out the relevant range of each item.
ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i]));
}

View File

@@ -10,14 +10,8 @@ logger = logging.getLogger(__file__)
from . import _py_mscclpp
__all__ = (
"Comm",
"MscclppUniqueId",
"MSCCLPP_UNIQUE_ID_BYTES",
"TransportType",
)
_Comm = _py_mscclpp._Comm
_RegisteredMemory = _py_mscclpp._RegisteredMemory
_P2PHandle = _py_mscclpp._P2PHandle
TransportType = _py_mscclpp.TransportType
@@ -67,7 +61,7 @@ _setup_logging()
class Comm:
"""Comm object; represents a mscclpp connection."""
_comm: _Comm
_c_comm: _Comm
@staticmethod
def init_rank_from_address(
@@ -97,16 +91,16 @@ class Comm:
def __init__(self, *, _comm: _Comm):
"""Construct a Comm object wrapping an internal _Comm handle."""
self._comm = _comm
self._c_comm = _comm
def __del__(self) -> None:
self.close()
def close(self) -> None:
"""Close the connection."""
if self._comm:
self._comm.close()
self._comm = None
if self._c_comm:
self._c_comm.close()
self._c_comm = None
@property
def rank(self) -> int:
@@ -114,7 +108,7 @@ class Comm:
Assumes the Comm is open.
"""
return self._comm.rank
return self._c_comm.rank
@property
def world_size(self) -> int:
@@ -122,11 +116,11 @@ class Comm:
Assumes the Comm is open.
"""
return self._comm.world_size
return self._c_comm.world_size
def bootstrap_all_gather_int(self, val: int) -> list[int]:
"""AllGather an int value through the bootstrap interface."""
return self._comm.bootstrap_all_gather_int(val)
return self._c_comm.bootstrap_all_gather_int(val)
def all_gather_bytes(self, item: bytes) -> list[bytes]:
"""AllGather bytes (of different sizes) through the bootstrap interface.
@@ -134,7 +128,7 @@ class Comm:
:param item: the bytes object for this rank.
:return: a list of bytes objects; the ret[rank] object will be a new copy.
"""
return self._comm.all_gather_bytes(item)
return self._c_comm.all_gather_bytes(item)
def all_gather_json(self, item: Any) -> list[Any]:
"""AllGather JSON objects through the bootstrap interface.
@@ -163,7 +157,7 @@ class Comm:
data_size: int,
transport: int,
) -> None:
self._comm.connect(
self._c_comm.connect(
remote_rank,
tag,
data_ptr,
@@ -172,32 +166,79 @@ class Comm:
)
def connection_setup(self) -> None:
self._comm.connection_setup()
self._c_comm.connection_setup()
def launch_proxies(self) -> None:
self._comm.launch_proxies()
self._c_comm.launch_proxies()
def stop_proxies(self) -> None:
self._comm.stop_proxies()
self._c_comm.stop_proxies()
def register_buffer(
    self,
    data_ptr: int,
    size: int,
) -> "RegisteredMemory":
    """Register a buffer with the underlying comm and wrap the result.

    :param data_ptr: address of the buffer, as an integer.
    :param size: size of the buffer in bytes.
    :return: a RegisteredMemory bound to this Comm.
    """
    return RegisteredMemory(
        # Pass the Comm wrapper, not the raw _Comm: RegisteredMemory keeps it
        # as `_comm: Comm` and later reads `self._comm._c_comm`.
        comm=self,
        rm=self._c_comm.register_buffer(
            data_ptr=data_ptr,
            size=size,
        ),
    )
class RegisteredMemory:
_comm: Comm
_c_rm: _RegisteredMemory
def __init__(
self,
*,
comm: Comm,
rm: _RegisteredMemory,
):
self._comm = comm
self._c_rm = rm
def handles(self) -> list["P2PHandle"]:
return [
P2PHandle(self, h)
for h in self._comm.register_buffer(
data_ptr,
data_size,
P2PHandle(
comm=self._comm,
handle=h,
)
for h in self._c_rm.handles()
]
def _write(
self,
src_ptr: int,
size: int,
*,
src_offset: int = 0,
dst_offset: int = 0,
stream: int = 0,
) -> None:
self._c_rm.write_all(
comm=self._comm._c_comm,
src_data=src_ptr,
size=size,
src_offset=src_offset,
dst_offset=dst_offset,
stream=stream,
)
class P2PHandle:
_comm: Comm
_handle: _P2PHandle
_c_handle: _P2PHandle
def __init__(self, comm: Comm, handle: _P2PHandle):
def __init__(self, *, comm: Comm, handle: _P2PHandle):
self._comm = comm
self._handle = handle
self._c_handle = handle
def transport(self) -> TransportType:
return self._c_handle.transport()
def data_ptr(self) -> int:
return self._c_handle.data_ptr()

View File

@@ -95,11 +95,22 @@ def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
mscclpp.TransportType.P2P,
)
handles = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel())
rm = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel())
handles = rm.handles()
hamcrest.assert_that(
handles,
hamcrest.has_length(options.world_size - 1),
)
for handle in handles:
hamcrest.assert_that(
handle.transport(),
hamcrest.equal_to(mscclpp.TransportType.P2P),
)
# assuming P2P ...
hamcrest.assert_that(
handle.data_ptr(),
hamcrest.greater_than(0),
)
torch.cuda.synchronize()