Add connect(): expose TransportType and mscclppConnect in the nanobind module, wrap them in Comm.connect / Comm.connection_setup, and exercise them in a torch-based test

This commit is contained in:
Crutcher Dunnavant
2023-04-04 23:34:27 +00:00
committed by Crutcher Dunnavant
parent b6ea0ca266
commit 7753c38eb1
4 changed files with 115 additions and 24 deletions

View File

@@ -136,6 +136,11 @@ NB_MODULE(_py_mscclpp, m) {
mscclppSetLogHandler(mscclppDefaultLogHandler);
});
nb::enum_<mscclppTransport_t>(m, "TransportType")
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
.value("IB", mscclppTransport_t::mscclppTransportIB);
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
.def_static(
@@ -215,6 +220,33 @@ NB_MODULE(_py_mscclpp, m) {
"Is this comm object closed?")
.def_ro("rank", &_Comm::_rank)
.def_ro("world_size", &_Comm::_world_size)
.def(
"connect",
[](_Comm& self,
int remote_rank,
int tag,
uint64_t local_buff,
uint64_t buff_size,
mscclppTransport_t transport_type) -> void {
checkResult(
mscclppConnect(
self._handle,
remote_rank,
tag,
reinterpret_cast<void*>(local_buff),
buff_size,
transport_type,
0 // ibDev
),
"Connect failed");
},
"remote_rank"_a,
"tag"_a,
"local_buf"_a,
"buff_size"_a,
"transport_type"_a,
nb::call_guard<nb::gil_scoped_release>(),
"Attach a local buffer to a remote connection.")
.def(
"connection_setup",
[](_Comm& comm) {
@@ -222,7 +254,7 @@ NB_MODULE(_py_mscclpp, m) {
return maybe(
mscclppConnectionSetup(comm._handle),
true,
"Failed to settup MSCCLPP connection");
"Failed to setup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")

View File

@@ -11,11 +11,14 @@ logger = logging.getLogger(__file__)
from . import _py_mscclpp
__all__ = (
"Comm",
"MscclppUniqueId",
"MSCCLPP_UNIQUE_ID_BYTES",
"TransportType",
)
_Comm = _py_mscclpp._Comm
TransportType = _py_mscclpp.TransportType
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
@@ -100,8 +103,9 @@ class Comm:
def close(self) -> None:
"""Close the connection."""
self._comm.close()
self._comm = None
if self._comm:
self._comm.close()
self._comm = None
@property
def rank(self) -> int:
@@ -149,3 +153,17 @@ class Comm:
:return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy.
"""
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
def connect(
    self,
    remote_rank: int,
    tag: int,
    data_ptr: int,
    data_size: int,
    transport: "TransportType",
) -> None:
    """Attach a local buffer to a connection with a remote rank.

    Thin wrapper over the native binding, which forwards to
    ``mscclppConnect`` (the C++ side reinterpret-casts ``data_ptr``
    to ``void*``).

    :param remote_rank: rank of the peer to connect to.
    :param tag: tag distinguishing multiple connections to the same peer.
    :param data_ptr: address of the local buffer as an integer
        (e.g. ``tensor.data_ptr()``).
    :param data_size: size of the buffer, in bytes.
    :param transport: transport to use (``TransportType.P2P`` / ``SHM`` / ``IB``).
    """
    self._comm.connect(
        remote_rank,
        tag,
        data_ptr,
        data_size,
        transport,
    )
def connection_setup(self) -> None:
    """Run connection setup for MSCCLPP over the underlying comm handle."""
    handle = self._comm
    handle.connection_setup()

View File

@@ -1,7 +1,9 @@
import argparse
import os
from dataclasses import dataclass
import hamcrest
import torch
import mscclpp
@@ -11,26 +13,7 @@ class Example:
rank: int
def main():
p = argparse.ArgumentParser()
p.add_argument("--rank", type=int, required=True)
p.add_argument("--world_size", type=int, required=True)
p.add_argument("--port", default=50000)
options = p.parse_args()
comm_options = dict(
address=f"127.0.0.1:{options.port}",
rank=options.rank,
world_size=options.world_size,
)
print(f"{comm_options=}", flush=True)
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
# comm.connection_setup()
hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.bootstrap_all_gather_int(options.rank + 42),
hamcrest.equal_to(
@@ -41,6 +24,8 @@ def main():
),
)
def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
hamcrest.equal_to(
@@ -51,6 +36,8 @@ def main():
),
)
def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_json({"rank": options.rank}),
hamcrest.equal_to(
@@ -71,6 +58,8 @@ def main():
),
)
def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_pickle(Example(rank=options.rank)),
hamcrest.equal_to(
@@ -81,7 +70,57 @@ def main():
),
)
comm.close()
comm.connection_setup()
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
    """Register a CUDA buffer with the next rank and run connection setup.

    Each rank connects its buffer to rank ``(rank + 1) % world_size``
    over the P2P transport, then finalizes the connection.
    """
    buf = torch.zeros(
        [options.world_size], dtype=torch.int64, device="cuda"
    ).contiguous()
    tag = 0
    remote_rank = (options.rank + 1) % options.world_size
    comm.connect(
        remote_rank,
        tag,
        buf.data_ptr(),
        buf.element_size() * buf.numel(),
        # Use the public re-export instead of reaching into the private
        # _py_mscclpp extension module.
        mscclpp.TransportType.P2P,
    )
    comm.connection_setup()
def main():
p = argparse.ArgumentParser()
p.add_argument("--rank", type=int, required=True)
p.add_argument("--world_size", type=int, required=True)
p.add_argument("--port", default=50000)
options = p.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(options.rank)
comm_options = dict(
address=f"127.0.0.1:{options.port}",
rank=options.rank,
world_size=options.world_size,
)
print(f"{comm_options=}", flush=True)
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
# comm.connection_setup()
hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
try:
_test_allgather_int(options, comm)
_test_allgather_bytes(options, comm)
_test_allgather_json(options, comm)
_test_allgather_pickle(options, comm)
_test_allgather_torch(options, comm)
finally:
comm.close()
if __name__ == "__main__":