From 423affeaa67dfcf093c2f28644c72542f169a5b5 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Wed, 29 Mar 2023 00:40:24 -0700 Subject: [PATCH] all gather bytes, json, pickle --- python/src/_py_mscclpp.cpp | 97 +++++++++++++++------- python/src/mscclpp/__init__.py | 57 ++++++++++++- python/src/mscclpp/tests/bootstrap_test.py | 38 ++++++++- 3 files changed, 158 insertions(+), 34 deletions(-) diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 1ddf215c..e8523c57 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -73,7 +74,7 @@ Val maybe(mscclppResult_t status, Val val, const std::string& format, Args... ar } // Wrapper around connection state. -struct MscclppComm +struct _Comm { int _rank; int _world_size; @@ -81,10 +82,10 @@ struct MscclppComm bool _is_open; public: - MscclppComm(int rank, int world_size, mscclppComm_t handle) + _Comm(int rank, int world_size, mscclppComm_t handle) : _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {} - ~MscclppComm() + ~_Comm() { close(); } @@ -104,14 +105,14 @@ public: void check_open() { if (!_is_open) { - throw std::invalid_argument("MscclppComm is not open"); + throw std::invalid_argument("_Comm is not open"); } } }; static const std::string DOC_MscclppUniqueId = "MSCCLPP Unique Id; used by the MPI Interface"; -static const std::string DOC_MscclppComm = "MSCCLPP Communications Handle"; +static const std::string DOC__Comm = "MSCCLPP Communications Handle"; NB_MODULE(_py_mscclpp, m) { @@ -141,8 +142,8 @@ NB_MODULE(_py_mscclpp, m) }) .def("bytes", [](mscclppUniqueId id) { return nb::bytes(id.internal, sizeof(id.internal)); }); - nb::class_(m, "MscclppComm") - .def_ro_static("__doc__", &DOC_MscclppComm) + nb::class_<_Comm>(m, "_Comm") + .def_ro_static("__doc__", &DOC__Comm) .def_static( "init_rank_from_address", [](const std::string& address, int rank, int world_size) { @@ -153,7 +154,7 @@ NB_MODULE(_py_mscclpp, m) address, rank, world_size); - return new MscclppComm(rank, world_size, handle); + return new _Comm(rank, world_size, handle); }, nb::rv_policy::take_ownership, nb::call_guard(), "address"_a, "rank"_a, "world_size"_a, @@ -168,60 +169,94 @@ NB_MODULE(_py_mscclpp, m) id.internal, rank, world_size); - return new MscclppComm(rank, world_size, handle); + return new _Comm(rank, world_size, handle); }, nb::rv_policy::take_ownership, nb::call_guard(), "id"_a, "rank"_a, "world_size"_a, "Initialize comms given u UniqueID, rank, and world_size") .def( - "opened", [](MscclppComm& comm) { return comm._is_open; }, "Is this comm object opened?") + "opened", [](_Comm& comm) { return comm._is_open; }, "Is this comm object opened?") .def( - "closed", [](MscclppComm& comm) { return !comm._is_open; }, "Is this comm object closed?") - .def_ro( "rank", &MscclppComm::_rank) - .def_ro( "world_size", &MscclppComm::_world_size) + "closed", [](_Comm& comm) { return !comm._is_open; }, "Is this comm object closed?") + .def_ro( "rank", &_Comm::_rank) + .def_ro( "world_size", &_Comm::_world_size) .def( "connection_setup", - [](MscclppComm& comm) { + [](_Comm& comm) { comm.check_open(); return maybe(mscclppConnectionSetup(comm._handle), true, "Failed to settup MSCCLPP connection"); }, nb::call_guard(), "Run connection setup for MSCCLPP.") .def( "launch_proxy", - [](MscclppComm& comm) { + [](_Comm& comm) { comm.check_open(); return maybe(mscclppProxyLaunch(comm._handle), true, "Failed to launch MSCCLPP proxy"); }, nb::call_guard(), "Start the MSCCLPP proxy.") .def( "stop_proxy", - [](MscclppComm& comm) { + [](_Comm& comm) { comm.check_open(); return maybe(mscclppProxyStop(comm._handle), true, "Failed to stop MSCCLPP proxy"); }, nb::call_guard(), "Start the MSCCLPP proxy.") - .def("close", &MscclppComm::close, nb::call_guard()) - .def("__del__", &MscclppComm::close, nb::call_guard()) + .def("close", &_Comm::close, nb::call_guard()) + .def("__del__", &_Comm::close, nb::call_guard()) .def("connection_setup", - [](MscclppComm& comm) -> void { + [](_Comm& comm) -> void { comm.check_open(); checkResult(mscclppConnectionSetup(comm._handle), "Connection Setup Failed"); }, nb::call_guard()) .def( "bootstrap_all_gather_int", - [](MscclppComm& comm, int val) -> std::vector { - std::vector buf(comm._world_size); - buf[comm._rank] = val; - mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)); - return buf; + [](_Comm& comm, int val) -> std::vector { + std::vector buf(comm._world_size); + buf[comm._rank] = val; + mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)); + return buf; }, - nb::call_guard()) + nb::call_guard(), + "val"_a, + "all-gather ints over the bootstrap connection.") .def( - "bootstrap_all_gather", - [](MscclppComm& comm, void* data, int size) { - comm.check_open(); - return maybe(mscclppBootstrapAllGather(comm._handle, data, size), true, "Failed to stop MSCCLPP proxy"); - }, - nb::call_guard()); + "all_gather_bytes", + [](_Comm& comm, nb::bytes& item) { + // First, all-gather the sizes of all bytes. + std::vector sizes(comm._world_size); + sizes[comm._rank] = item.size(); + checkResult( + mscclppBootstrapAllGather(comm._handle, sizes.data(), sizeof(size_t)), + "bootstrapAllGather failed."); + + // Next, find the largest message to send. + size_t max_size = *std::max_element(sizes.begin(), sizes.end()); + + // Allocate an all-gather buffer large enough for max * world_size. + std::shared_ptr data_buf(new char[max_size * comm._world_size]); + + // Copy the local item into the buffer. + std::memcpy( + &data_buf[comm._rank * max_size], + item.c_str(), + item.size()); + + // all-gather the data buffer. + checkResult( + mscclppBootstrapAllGather(comm._handle, data_buf.get(), max_size), + "bootstrapAllGather failed."); + + // Build a response vector. + std::vector ret; + for (int i = 0; i < comm._world_size; ++i) { + // Copy out the relevant range of each item. + ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i])); + } + return ret; + }, + nb::call_guard(), + "item"_a, + "all-gather bytes over the bootstrap connection; sizes do not need to match." + ); } diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index e825b92d..b3a578d2 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -1,13 +1,66 @@ +from typing import Any +import json +import pickle + from . import _py_mscclpp __all__ = ( "MscclppUniqueId", "MSCCLPP_UNIQUE_ID_BYTES", - "MscclppComm", ) +_Comm = _py_mscclpp._Comm + MscclppUniqueId = _py_mscclpp.MscclppUniqueId MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES -MscclppComm = _py_mscclpp.MscclppComm +class Comm: + _comm: _Comm + + @staticmethod + def init_rank_from_address( + address: str, + rank: int, + world_size: int, + ): + return Comm( + _comm = _Comm.init_rank_from_address( + address=address, + rank=rank, + world_size=world_size, + ), + ) + + def __init__(self, *, _comm: _Comm): + self._comm = _comm + + def close(self) -> None: + self._comm.close() + self._comm = None + + @property + def rank(self) -> int: + return self._comm.rank + + @property + def world_size(self) -> int: + return self._comm.world_size + + def bootstrap_all_gather_int(self, val: int) -> list[int]: + return self._comm.bootstrap_all_gather_int(val) + + def all_gather_bytes(self, item: bytes) -> list[bytes]: + return self._comm.all_gather_bytes(item) + + def all_gather_json(self, item: Any) -> list[Any]: + return [ + json.loads(b.decode('utf-8')) + for b in self.all_gather_bytes(bytes(json.dumps(item), 'utf-8')) + ] + + def all_gather_pickle(self, item: Any) -> list[Any]: + return [ + pickle.loads(b) + for b in self.all_gather_bytes(pickle.dumps(item)) + ] diff --git a/python/src/mscclpp/tests/bootstrap_test.py b/python/src/mscclpp/tests/bootstrap_test.py index bb17bf15..a716f803 100644 --- a/python/src/mscclpp/tests/bootstrap_test.py +++ b/python/src/mscclpp/tests/bootstrap_test.py @@ -1,6 +1,10 @@ +from dataclasses import dataclass import argparse import hamcrest import mscclpp +@dataclass +class Example: + rank: int def main(): p = argparse.ArgumentParser() @@ -16,7 +20,7 @@ def main(): ) print(f'{comm_options=}', flush=True) - comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options) + comm = mscclpp.Comm.init_rank_from_address(**comm_options) # comm.connection_setup() hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank)) @@ -30,6 +34,38 @@ def main(): ]), ) + hamcrest.assert_that( + comm.all_gather_bytes(b'abc' * (1 + options.rank)), + hamcrest.equal_to([ + b'abc', + b'abcabc', + ]), + ) + + hamcrest.assert_that( + comm.all_gather_json({'rank': options.rank}), + hamcrest.equal_to([ + {'rank': 0}, + {'rank': 1}, + ]), + ) + + hamcrest.assert_that( + comm.all_gather_json([options.rank, 42]), + hamcrest.equal_to([ + [0, 42], + [1, 42], + ]), + ) + + hamcrest.assert_that( + comm.all_gather_pickle(Example(rank=options.rank)), + hamcrest.equal_to([ + Example(rank=0), + Example(rank=1), + ]), + ) + comm.close() if __name__ == '__main__':