mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
working on connect
This commit is contained in:
committed by
Crutcher Dunnavant
parent
b6ea0ca266
commit
7753c38eb1
@@ -136,6 +136,11 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
mscclppSetLogHandler(mscclppDefaultLogHandler);
|
||||
});
|
||||
|
||||
nb::enum_<mscclppTransport_t>(m, "TransportType")
|
||||
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
|
||||
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
|
||||
.value("IB", mscclppTransport_t::mscclppTransportIB);
|
||||
|
||||
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
|
||||
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
|
||||
.def_static(
|
||||
@@ -215,6 +220,33 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Is this comm object closed?")
|
||||
.def_ro("rank", &_Comm::_rank)
|
||||
.def_ro("world_size", &_Comm::_world_size)
|
||||
.def(
|
||||
"connect",
|
||||
[](_Comm& self,
|
||||
int remote_rank,
|
||||
int tag,
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size,
|
||||
mscclppTransport_t transport_type) -> void {
|
||||
checkResult(
|
||||
mscclppConnect(
|
||||
self._handle,
|
||||
remote_rank,
|
||||
tag,
|
||||
reinterpret_cast<void*>(local_buff),
|
||||
buff_size,
|
||||
transport_type,
|
||||
0 // ibDev
|
||||
),
|
||||
"Connect failed");
|
||||
},
|
||||
"remote_rank"_a,
|
||||
"tag"_a,
|
||||
"local_buf"_a,
|
||||
"buff_size"_a,
|
||||
"transport_type"_a,
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Attach a local buffer to a remote connection.")
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](_Comm& comm) {
|
||||
@@ -222,7 +254,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
return maybe(
|
||||
mscclppConnectionSetup(comm._handle),
|
||||
true,
|
||||
"Failed to settup MSCCLPP connection");
|
||||
"Failed to setup MSCCLPP connection");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Run connection setup for MSCCLPP.")
|
||||
|
||||
@@ -11,11 +11,14 @@ logger = logging.getLogger(__file__)
|
||||
from . import _py_mscclpp
|
||||
|
||||
__all__ = (
|
||||
"Comm",
|
||||
"MscclppUniqueId",
|
||||
"MSCCLPP_UNIQUE_ID_BYTES",
|
||||
"TransportType",
|
||||
)
|
||||
|
||||
_Comm = _py_mscclpp._Comm
|
||||
TransportType = _py_mscclpp.TransportType
|
||||
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
|
||||
@@ -100,8 +103,9 @@ class Comm:
|
||||
|
||||
def close(self) -> None:
    """Close the connection.

    Safe to call more than once: after the first call the wrapped
    communicator reference is dropped and later calls do nothing.
    """
    # Detach the handle before closing it so the reference is cleared even
    # if the underlying close() raises.
    comm, self._comm = self._comm, None
    if comm is not None:
        comm.close()
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
@@ -149,3 +153,17 @@ class Comm:
|
||||
:return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy.
|
||||
"""
|
||||
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
|
||||
|
||||
def connect(
    self,
    remote_rank: int,
    tag: int,
    data_ptr: int,
    data_size: int,
    transport: "TransportType",
) -> None:
    """Attach a local buffer to a connection with a remote rank.

    All arguments are forwarded positionally, in order, to the ``connect``
    method of the wrapped communicator.

    :param remote_rank: rank of the peer to connect to.
    :param tag: integer tag forwarded to the wrapped communicator.
    :param data_ptr: address of the local buffer, as an integer
        (e.g. the value of ``torch.Tensor.data_ptr()``).
    :param data_size: size of the local buffer; presumably in bytes, as the
        caller in this commit passes ``element_size() * numel()``.
    :param transport: a ``TransportType`` member such as
        ``TransportType.P2P``.
    """
    self._comm.connect(remote_rank, tag, data_ptr, data_size, transport)
|
||||
|
||||
def connection_setup(self) -> None:
    """Run connection setup on the wrapped communicator.

    Delegates to the ``connection_setup`` method of ``self._comm`` and
    discards whatever that call returns.
    """
    native_comm = self._comm
    native_comm.connection_setup()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import hamcrest
|
||||
import torch
|
||||
|
||||
import mscclpp
|
||||
|
||||
@@ -11,26 +13,7 @@ class Example:
|
||||
rank: int
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rank", type=int, required=True)
|
||||
p.add_argument("--world_size", type=int, required=True)
|
||||
p.add_argument("--port", default=50000)
|
||||
options = p.parse_args()
|
||||
|
||||
comm_options = dict(
|
||||
address=f"127.0.0.1:{options.port}",
|
||||
rank=options.rank,
|
||||
world_size=options.world_size,
|
||||
)
|
||||
print(f"{comm_options=}", flush=True)
|
||||
|
||||
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
|
||||
# comm.connection_setup()
|
||||
|
||||
hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
|
||||
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
|
||||
|
||||
def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.bootstrap_all_gather_int(options.rank + 42),
|
||||
hamcrest.equal_to(
|
||||
@@ -41,6 +24,8 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
|
||||
hamcrest.equal_to(
|
||||
@@ -51,6 +36,8 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_json({"rank": options.rank}),
|
||||
hamcrest.equal_to(
|
||||
@@ -71,6 +58,8 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_pickle(Example(rank=options.rank)),
|
||||
hamcrest.equal_to(
|
||||
@@ -81,7 +70,57 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
comm.close()
|
||||
comm.connection_setup()
|
||||
|
||||
|
||||
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
    """Register a CUDA tensor with the next rank and run connection setup.

    Allocates a zeroed int64 tensor with one element per rank on the CUDA
    device, connects it to rank ``(rank + 1) % world_size`` with tag 0 over
    the P2P transport, then calls ``comm.connection_setup()``. No data is
    exchanged or asserted here.
    """
    # torch.zeros already returns a contiguous tensor, so no extra
    # .contiguous() call is needed before taking its address.
    buf = torch.zeros([options.world_size], dtype=torch.int64, device="cuda")
    tag = 0
    remote_rank = (options.rank + 1) % options.world_size

    # NOTE(review): only the raw address of `buf` is handed to the
    # communicator, and `buf` is a local that goes out of scope when this
    # function returns. Confirm the connection does not use the buffer after
    # that point.
    comm.connect(
        remote_rank,
        tag,
        buf.data_ptr(),
        buf.element_size() * buf.numel(),
        # Use the public re-export rather than the private extension module.
        mscclpp.TransportType.P2P,
    )

    comm.connection_setup()
|
||||
|
||||
|
||||
def main():
    """Parse CLI options, open a communicator, and run the test functions.

    Command-line options:
      --rank        this process's rank (required, int)
      --world_size  total number of ranks (required, int)
      --port        port on 127.0.0.1 used for the address (default 50000)

    The communicator is always closed on exit, including when an assertion
    or a test function fails.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", type=int, required=True)
    parser.add_argument("--world_size", type=int, required=True)
    # Parse the port as an int so a malformed value is rejected up front
    # instead of producing an invalid address string.
    parser.add_argument("--port", type=int, default=50000)
    options = parser.parse_args()

    # Expose only the GPU whose index equals this rank.
    # NOTE(review): presumably relies on CUDA not having been initialized yet
    # by the earlier `import torch` - confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(options.rank)

    comm_options = dict(
        address=f"127.0.0.1:{options.port}",
        rank=options.rank,
        world_size=options.world_size,
    )
    print(f"{comm_options=}", flush=True)

    comm = mscclpp.Comm.init_rank_from_address(**comm_options)

    try:
        # The sanity checks live inside the try block so that a failed
        # assertion still closes the communicator.
        hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
        hamcrest.assert_that(
            comm.world_size, hamcrest.equal_to(options.world_size)
        )

        _test_allgather_int(options, comm)
        _test_allgather_bytes(options, comm)
        _test_allgather_json(options, comm)
        _test_allgather_pickle(options, comm)
        _test_allgather_torch(options, comm)
    finally:
        comm.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user