mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
types
This commit is contained in:
@@ -94,16 +94,21 @@ struct _Comm {
|
||||
int _world_size;
|
||||
mscclppComm_t _handle;
|
||||
bool _is_open;
|
||||
bool _proxies_running;
|
||||
|
||||
public:
|
||||
_Comm(int rank, int world_size, mscclppComm_t handle)
|
||||
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {}
|
||||
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {}
|
||||
|
||||
~_Comm() { close(); }
|
||||
|
||||
// Close should be safe to call on a closed handle.
|
||||
void close() {
|
||||
if (_is_open) {
|
||||
if (_proxies_running) {
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
}
|
||||
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
|
||||
_handle = NULL;
|
||||
_is_open = false;
|
||||
@@ -138,20 +143,15 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
|
||||
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
|
||||
|
||||
m.def("_bind_log_handler", [](nb::callable cb) {
|
||||
m.def("_bind_log_handler", [](nb::callable cb) -> void {
|
||||
_log_callback = nb::borrow<nb::callable>(cb);
|
||||
mscclppSetLogHandler(_LogHandler);
|
||||
});
|
||||
m.def("_release_log_handler", []() {
|
||||
m.def("_release_log_handler", []() -> void {
|
||||
_log_callback.reset();
|
||||
mscclppSetLogHandler(mscclppDefaultLogHandler);
|
||||
});
|
||||
|
||||
m.def("_setup", []() {
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
});
|
||||
|
||||
nb::enum_<mscclppTransport_t>(m, "TransportType")
|
||||
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
|
||||
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
|
||||
@@ -161,7 +161,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
|
||||
.def_static(
|
||||
"from_context",
|
||||
[]() {
|
||||
[]() -> mscclppUniqueId {
|
||||
mscclppUniqueId uniqueId;
|
||||
return maybe(
|
||||
mscclppGetUniqueId(&uniqueId),
|
||||
@@ -171,7 +171,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def_static(
|
||||
"from_bytes",
|
||||
[](nb::bytes source) {
|
||||
[](nb::bytes source) -> mscclppUniqueId {
|
||||
if (source.size() != MSCCLPP_UNIQUE_ID_BYTES) {
|
||||
throw std::invalid_argument(string_format(
|
||||
"Requires exactly %d bytes; found %d",
|
||||
@@ -192,7 +192,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
.def_ro_static("__doc__", &DOC__Comm)
|
||||
.def_static(
|
||||
"init_rank_from_address",
|
||||
[](const std::string& address, int rank, int world_size) {
|
||||
[](const std::string& address, int rank, int world_size) -> _Comm* {
|
||||
mscclppComm_t handle;
|
||||
checkResult(
|
||||
mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
|
||||
@@ -210,7 +210,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Initialize comms given an IP address, rank, and world_size")
|
||||
.def_static(
|
||||
"init_rank_from_id",
|
||||
[](const mscclppUniqueId& id, int rank, int world_size) {
|
||||
[](const mscclppUniqueId& id, int rank, int world_size) -> _Comm* {
|
||||
mscclppComm_t handle;
|
||||
checkResult(
|
||||
mscclppCommInitRankFromId(&handle, world_size, id, rank),
|
||||
@@ -228,25 +228,28 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Initialize comms given u UniqueID, rank, and world_size")
|
||||
.def(
|
||||
"opened",
|
||||
[](_Comm& comm) { return comm._is_open; },
|
||||
[](_Comm& comm) -> bool { return comm._is_open; },
|
||||
"Is this comm object opened?")
|
||||
.def(
|
||||
"closed",
|
||||
[](_Comm& comm) { return !comm._is_open; },
|
||||
[](_Comm& comm) -> bool { return !comm._is_open; },
|
||||
"Is this comm object closed?")
|
||||
.def_ro("rank", &_Comm::_rank)
|
||||
.def_ro("world_size", &_Comm::_world_size)
|
||||
.def(
|
||||
"connect",
|
||||
[](_Comm& self,
|
||||
[](_Comm& comm,
|
||||
int remote_rank,
|
||||
int tag,
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size,
|
||||
mscclppTransport_t transport_type) -> void {
|
||||
if (comm._proxies_running) {
|
||||
throw std::invalid_argument("Proxy Threads Already Running");
|
||||
}
|
||||
RETRY(
|
||||
mscclppConnect(
|
||||
self._handle,
|
||||
comm._handle,
|
||||
remote_rank,
|
||||
tag,
|
||||
reinterpret_cast<void*>(local_buff),
|
||||
@@ -273,24 +276,29 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Run connection setup for MSCCLPP.")
|
||||
.def(
|
||||
"launch_proxy",
|
||||
[](_Comm& comm) {
|
||||
"launch_proxies",
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
return maybe(
|
||||
if (comm._proxies_running) {
|
||||
throw std::invalid_argument("Proxy Threads Already Running");
|
||||
}
|
||||
checkResult(
|
||||
mscclppProxyLaunch(comm._handle),
|
||||
true,
|
||||
"Failed to launch MSCCLPP proxy");
|
||||
comm._proxies_running = true;
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Start the MSCCLPP proxy.")
|
||||
.def(
|
||||
"stop_proxy",
|
||||
[](_Comm& comm) {
|
||||
"stop_proxies",
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
return maybe(
|
||||
mscclppProxyStop(comm._handle),
|
||||
true,
|
||||
"Failed to stop MSCCLPP proxy");
|
||||
if (comm._proxies_running) {
|
||||
checkResult(
|
||||
mscclppProxyStop(comm._handle),
|
||||
"Failed to stop MSCCLPP proxy");
|
||||
comm._proxies_running = false;
|
||||
}
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Start the MSCCLPP proxy.")
|
||||
@@ -318,7 +326,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"all-gather ints over the bootstrap connection.")
|
||||
.def(
|
||||
"all_gather_bytes",
|
||||
[](_Comm& comm, nb::bytes& item) {
|
||||
[](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
|
||||
// First, all-gather the sizes of all bytes.
|
||||
std::vector<size_t> sizes(comm._world_size);
|
||||
sizes[comm._rank] = item.size();
|
||||
|
||||
@@ -10,8 +10,6 @@ logger = logging.getLogger(__file__)
|
||||
|
||||
from . import _py_mscclpp
|
||||
|
||||
_py_mscclpp._setup()
|
||||
|
||||
__all__ = (
|
||||
"Comm",
|
||||
"MscclppUniqueId",
|
||||
@@ -156,7 +154,12 @@ class Comm:
|
||||
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
|
||||
|
||||
def connect(
|
||||
self, remote_rank: int, tag: int, data_ptr, data_size, transport: int
|
||||
self,
|
||||
remote_rank: int,
|
||||
tag: int,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
transport: int,
|
||||
) -> None:
|
||||
self._comm.connect(
|
||||
remote_rank,
|
||||
@@ -168,3 +171,9 @@ class Comm:
|
||||
|
||||
def connection_setup(self) -> None:
|
||||
self._comm.connection_setup()
|
||||
|
||||
def launch_proxies(self) -> None:
|
||||
self._comm.launch_proxies()
|
||||
|
||||
def stop_proxies(self) -> None:
|
||||
self._comm.stop_proxies()
|
||||
|
||||
@@ -99,6 +99,10 @@ def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
|
||||
comm.connection_setup()
|
||||
|
||||
comm.launch_proxies()
|
||||
comm.stop_proxies()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user