all gather bytes, json, pickle

This commit is contained in:
Crutcher Dunnavant
2023-03-29 00:40:24 -07:00
committed by root
parent 17e1885981
commit 423affeaa6
3 changed files with 158 additions and 34 deletions

View File

@@ -3,6 +3,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <vector>
#include <cstdio>
#include <cstring>
@@ -73,7 +74,7 @@ Val maybe(mscclppResult_t status, Val val, const std::string& format, Args... ar
}
// Wrapper around connection state.
struct MscclppComm
struct _Comm
{
int _rank;
int _world_size;
@@ -81,10 +82,10 @@ struct MscclppComm
bool _is_open;
public:
MscclppComm(int rank, int world_size, mscclppComm_t handle)
_Comm(int rank, int world_size, mscclppComm_t handle)
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {}
~MscclppComm()
~_Comm()
{
close();
}
@@ -104,14 +105,14 @@ public:
void check_open()
{
if (!_is_open) {
throw std::invalid_argument("MscclppComm is not open");
throw std::invalid_argument("_Comm is not open");
}
}
};
static const std::string DOC_MscclppUniqueId = "MSCCLPP Unique Id; used by the MPI Interface";
static const std::string DOC_MscclppComm = "MSCCLPP Communications Handle";
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
NB_MODULE(_py_mscclpp, m)
{
@@ -141,8 +142,8 @@ NB_MODULE(_py_mscclpp, m)
})
.def("bytes", [](mscclppUniqueId id) { return nb::bytes(id.internal, sizeof(id.internal)); });
nb::class_<MscclppComm>(m, "MscclppComm")
.def_ro_static("__doc__", &DOC_MscclppComm)
nb::class_<_Comm>(m, "_Comm")
.def_ro_static("__doc__", &DOC__Comm)
.def_static(
"init_rank_from_address",
[](const std::string& address, int rank, int world_size) {
@@ -153,7 +154,7 @@ NB_MODULE(_py_mscclpp, m)
address,
rank,
world_size);
return new MscclppComm(rank, world_size, handle);
return new _Comm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(), "address"_a, "rank"_a, "world_size"_a,
@@ -168,60 +169,94 @@ NB_MODULE(_py_mscclpp, m)
id.internal,
rank,
world_size);
return new MscclppComm(rank, world_size, handle);
return new _Comm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(), "id"_a, "rank"_a, "world_size"_a,
"Initialize comms given u UniqueID, rank, and world_size")
.def(
"opened", [](MscclppComm& comm) { return comm._is_open; }, "Is this comm object opened?")
"opened", [](_Comm& comm) { return comm._is_open; }, "Is this comm object opened?")
.def(
"closed", [](MscclppComm& comm) { return !comm._is_open; }, "Is this comm object closed?")
.def_ro( "rank", &MscclppComm::_rank)
.def_ro( "world_size", &MscclppComm::_world_size)
"closed", [](_Comm& comm) { return !comm._is_open; }, "Is this comm object closed?")
.def_ro( "rank", &_Comm::_rank)
.def_ro( "world_size", &_Comm::_world_size)
.def(
"connection_setup",
[](MscclppComm& comm) {
[](_Comm& comm) {
comm.check_open();
return maybe(mscclppConnectionSetup(comm._handle), true, "Failed to settup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(), "Run connection setup for MSCCLPP.")
.def(
"launch_proxy",
[](MscclppComm& comm) {
[](_Comm& comm) {
comm.check_open();
return maybe(mscclppProxyLaunch(comm._handle), true, "Failed to launch MSCCLPP proxy");
},
nb::call_guard<nb::gil_scoped_release>(), "Start the MSCCLPP proxy.")
.def(
"stop_proxy",
[](MscclppComm& comm) {
[](_Comm& comm) {
comm.check_open();
return maybe(mscclppProxyStop(comm._handle), true, "Failed to stop MSCCLPP proxy");
},
nb::call_guard<nb::gil_scoped_release>(), "Start the MSCCLPP proxy.")
.def("close", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
.def("__del__", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
.def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
.def("connection_setup",
[](MscclppComm& comm) -> void {
[](_Comm& comm) -> void {
comm.check_open();
checkResult(mscclppConnectionSetup(comm._handle), "Connection Setup Failed");
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather_int",
[](MscclppComm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
return buf;
[](_Comm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
return buf;
},
nb::call_guard<nb::gil_scoped_release>())
nb::call_guard<nb::gil_scoped_release>(),
"val"_a,
"all-gather ints over the bootstrap connection.")
.def(
"bootstrap_all_gather",
[](MscclppComm& comm, void* data, int size) {
comm.check_open();
return maybe(mscclppBootstrapAllGather(comm._handle, data, size), true, "Failed to stop MSCCLPP proxy");
},
nb::call_guard<nb::gil_scoped_release>());
"all_gather_bytes",
[](_Comm& comm, nb::bytes& item) {
// First, all-gather the sizes of all bytes.
std::vector<size_t> sizes(comm._world_size);
sizes[comm._rank] = item.size();
checkResult(
mscclppBootstrapAllGather(comm._handle, sizes.data(), sizeof(size_t)),
"bootstrapAllGather failed.");
// Next, find the largest message to send.
size_t max_size = *std::max_element(sizes.begin(), sizes.end());
// Allocate an all-gather buffer large enough for max * world_size.
std::shared_ptr<char[]> data_buf(new char[max_size * comm._world_size]);
// Copy the local item into the buffer.
std::memcpy(
&data_buf[comm._rank * max_size],
item.c_str(),
item.size());
// all-gather the data buffer.
checkResult(
mscclppBootstrapAllGather(comm._handle, data_buf.get(), max_size),
"bootstrapAllGather failed.");
// Build a response vector.
std::vector<nb::bytes> ret;
for (int i = 0; i < comm._world_size; ++i) {
// Copy out the relevant range of each item.
ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i]));
}
return ret;
},
nb::call_guard<nb::gil_scoped_release>(),
"item"_a,
"all-gather bytes over the bootstrap connection; sizes do not need to match."
);
}