[python] working on bootstrap all gather bug

This commit is contained in:
Crutcher Dunnavant
2023-03-28 14:32:43 -07:00
committed by root
parent 2c6460ce72
commit 8cac41c8ac
5 changed files with 95 additions and 46 deletions

View File

@@ -1,7 +1,9 @@
#include <mscclpp.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <vector>
#include <cstdio>
#include <cstring>
#include <memory>
@@ -52,12 +54,13 @@ template <typename... Args> void checkResult(mscclppResult_t status, const std::
case mscclppRemoteError:
case mscclppInProgress:
case mscclppNumResults:
// throw std::runtime_error(string_format(format, args...) + " : " + std::string(mscclppGetErrorString(status)));
throw std::runtime_error(string_format(format, args...));
case mscclppInvalidArgument:
case mscclppInvalidUsage:
default:
throw std::invalid_argument(string_format(format, args...));
throw std::invalid_argument(string_format(format, args...));
}
}
@@ -72,6 +75,8 @@ Val maybe(mscclppResult_t status, Val val, const std::string& format, Args... ar
// Wrapper around connection state.
struct MscclppComm
{
int _rank;
int _world_size;
mscclppComm_t _handle;
bool _is_open = false;
@@ -137,6 +142,8 @@ NB_MODULE(_py_mscclpp, m)
"init_rank_from_address",
[](const std::string& address, int rank, int world_size) {
MscclppComm comm = {0};
comm._rank = rank;
comm._world_size = world_size;
comm._is_open = true;
return maybe(mscclppCommInitRank(&comm._handle, world_size, address.c_str(), rank), comm,
"Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size);
@@ -147,6 +154,8 @@ NB_MODULE(_py_mscclpp, m)
"init_rank_from_id",
[](const mscclppUniqueId& id, int rank, int world_size) {
MscclppComm comm = {0};
comm._rank = rank;
comm._world_size = world_size;
comm._is_open = true;
return maybe(mscclppCommInitRankFromId(&comm._handle, world_size, id, rank), comm,
"Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size);
@@ -157,22 +166,8 @@ NB_MODULE(_py_mscclpp, m)
"opened", [](MscclppComm& comm) { return comm._is_open; }, "Is this comm object opened?")
.def(
"closed", [](MscclppComm& comm) { return !comm._is_open; }, "Is this comm object closed?")
.def(
"rank",
[](MscclppComm& comm) {
comm.check_open();
int rank;
return maybe(mscclppCommRank(comm._handle, &rank), rank, "Failed to retrieve MSCCLPP rank");
},
nb::call_guard<nb::gil_scoped_release>(), "The rank of this node.")
.def(
"size",
[](MscclppComm& comm) {
comm.check_open();
int size;
return maybe(mscclppCommSize(comm._handle, &size), size, "Failed to retrieve MSCCLPP world size");
},
nb::call_guard<nb::gil_scoped_release>(), "The world size of this node.")
.def_ro( "rank", &MscclppComm::_rank)
.def_ro( "world_size", &MscclppComm::_world_size)
.def(
"connection_setup",
[](MscclppComm& comm) {
@@ -196,6 +191,22 @@ NB_MODULE(_py_mscclpp, m)
nb::call_guard<nb::gil_scoped_release>(), "Start the MSCCLPP proxy.")
.def("close", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
.def("__del__", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
.def("connection_setup",
[](MscclppComm& comm) -> void {
comm.check_open();
checkResult(mscclppConnectionSetup(comm._handle), "Connection Setup Failed");
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather_int",
[](MscclppComm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
// this call segfaults; disabling this call does not.
checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)), "All Gather Failed");
return buf;
},
nb::call_guard<nb::gil_scoped_release>())
.def(
"bootstrap_all_gather",
[](MscclppComm& comm, void* data, int size) {