This commit is contained in:
Crutcher Dunnavant
2023-04-07 12:08:32 -07:00
parent d014693288
commit 44a8a539ad
3 changed files with 51 additions and 30 deletions

View File

@@ -94,16 +94,21 @@ struct _Comm {
int _world_size;
mscclppComm_t _handle;
bool _is_open;
bool _proxies_running;
public:
_Comm(int rank, int world_size, mscclppComm_t handle)
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {}
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {}
~_Comm() { close(); }
// Close should be safe to call on a closed handle.
void close() {
if (_is_open) {
if (_proxies_running) {
mscclppProxyStop(_handle);
_proxies_running = false;
}
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
_handle = NULL;
_is_open = false;
@@ -138,20 +143,15 @@ NB_MODULE(_py_mscclpp, m) {
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
m.def("_bind_log_handler", [](nb::callable cb) {
m.def("_bind_log_handler", [](nb::callable cb) -> void {
_log_callback = nb::borrow<nb::callable>(cb);
mscclppSetLogHandler(_LogHandler);
});
m.def("_release_log_handler", []() {
m.def("_release_log_handler", []() -> void {
_log_callback.reset();
mscclppSetLogHandler(mscclppDefaultLogHandler);
});
m.def("_setup", []() {
int device;
cudaGetDevice(&device);
});
nb::enum_<mscclppTransport_t>(m, "TransportType")
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
@@ -161,7 +161,7 @@ NB_MODULE(_py_mscclpp, m) {
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
.def_static(
"from_context",
[]() {
[]() -> mscclppUniqueId {
mscclppUniqueId uniqueId;
return maybe(
mscclppGetUniqueId(&uniqueId),
@@ -171,7 +171,7 @@ NB_MODULE(_py_mscclpp, m) {
nb::call_guard<nb::gil_scoped_release>())
.def_static(
"from_bytes",
[](nb::bytes source) {
[](nb::bytes source) -> mscclppUniqueId {
if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) {
throw std::invalid_argument(string_format(
"Requires exactly %d bytes; found %d",
@@ -192,7 +192,7 @@ NB_MODULE(_py_mscclpp, m) {
.def_ro_static("__doc__", &DOC__Comm)
.def_static(
"init_rank_from_address",
[](const std::string& address, int rank, int world_size) {
[](const std::string& address, int rank, int world_size) -> _Comm* {
mscclppComm_t handle;
checkResult(
mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
@@ -210,7 +210,7 @@ NB_MODULE(_py_mscclpp, m) {
"Initialize comms given an IP address, rank, and world_size")
.def_static(
"init_rank_from_id",
[](const mscclppUniqueId& id, int rank, int world_size) {
[](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* {
mscclppComm_t handle;
checkResult(
mscclppCommInitRankFromId(&handle, world_size, id, rank),
@@ -228,25 +228,28 @@ NB_MODULE(_py_mscclpp, m) {
"Initialize comms given u UniqueID, rank, and world_size")
.def(
"opened",
[](_Comm& comm) { return comm._is_open; },
[](_Comm& comm) -> bool { return comm._is_open; },
"Is this comm object opened?")
.def(
"closed",
[](_Comm& comm) { return !comm._is_open; },
[](_Comm& comm) -> bool { return !comm._is_open; },
"Is this comm object closed?")
.def_ro("rank", &_Comm::_rank)
.def_ro("world_size", &_Comm::_world_size)
.def(
"connect",
[](_Comm& self,
[](_Comm& comm,
int remote_rank,
int tag,
uint64_t local_buff,
uint64_t buff_size,
mscclppTransport_t transport_type) -> void {
if (comm._proxies_running) {
throw std::invalid_argument("Proxy Threads Already Running");
}
RETRY(
mscclppConnect(
self._handle,
comm._handle,
remote_rank,
tag,
reinterpret_cast<void*>(local_buff),
@@ -273,24 +276,29 @@ NB_MODULE(_py_mscclpp, m) {
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")
.def(
"launch_proxy",
[](_Comm& comm) {
"launch_proxies",
[](_Comm& comm) -> void {
comm.check_open();
return maybe(
if (comm._proxies_running) {
throw std::invalid_argument("Proxy Threads Already Running");
}
checkResult(
mscclppProxyLaunch(comm._handle),
true,
"Failed to launch MSCCLPP proxy");
comm._proxies_running = true;
},
nb::call_guard<nb::gil_scoped_release>(),
"Start the MSCCLPP proxy.")
.def(
"stop_proxy",
[](_Comm& comm) {
"stop_proxies",
[](_Comm& comm) -> void {
comm.check_open();
return maybe(
mscclppProxyStop(comm._handle),
true,
"Failed to stop MSCCLPP proxy");
if (comm._proxies_running) {
checkResult(
mscclppProxyStop(comm._handle),
"Failed to stop MSCCLPP proxy");
comm._proxies_running = false;
}
},
nb::call_guard<nb::gil_scoped_release>(),
"Start the MSCCLPP proxy.")
@@ -318,7 +326,7 @@ NB_MODULE(_py_mscclpp, m) {
"all-gather ints over the bootstrap connection.")
.def(
"all_gather_bytes",
[](_Comm& comm, nb::bytes& item) {
[](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
// First, all-gather the sizes of all bytes.
std::vector<size_t> sizes(comm._world_size);
sizes[comm._rank] = item.size();

View File

@@ -10,8 +10,6 @@ logger = logging.getLogger(__file__)
from . import _py_mscclpp
_py_mscclpp._setup()
__all__ = (
"Comm",
"MscclppUniqueId",
@@ -156,7 +154,12 @@ class Comm:
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
def connect(
self, remote_rank: int, tag: int, data_ptr, data_size, transport: int
self,
remote_rank: int,
tag: int,
data_ptr,
data_size: int,
transport: int,
) -> None:
self._comm.connect(
remote_rank,
@@ -168,3 +171,9 @@ class Comm:
def connection_setup(self) -> None:
self._comm.connection_setup()
def launch_proxies(self) -> None:
self._comm.launch_proxies()
def stop_proxies(self) -> None:
self._comm.stop_proxies()

View File

@@ -99,6 +99,10 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
comm.connection_setup()
comm.launch_proxies()
comm.stop_proxies()
def main():
p = argparse.ArgumentParser()