allocation fixes

This commit is contained in:
Crutcher Dunnavant
2023-03-28 23:11:41 -07:00
committed by root
parent 8cac41c8ac
commit 17e1885981
3 changed files with 28 additions and 23 deletions

View File

@@ -78,9 +78,12 @@ struct MscclppComm
int _rank;
int _world_size;
mscclppComm_t _handle;
bool _is_open = false;
bool _is_open;
public:
MscclppComm(int rank, int world_size, mscclppComm_t handle)
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {}
~MscclppComm()
{
close();
@@ -91,8 +94,10 @@ public:
{
if (_is_open) {
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
_handle = 0;
_handle = NULL;
_is_open = false;
_rank = -1;
_world_size = -1;
}
}
@@ -141,25 +146,31 @@ NB_MODULE(_py_mscclpp, m)
.def_static(
"init_rank_from_address",
[](const std::string& address, int rank, int world_size) {
MscclppComm comm = {0};
comm._rank = rank;
comm._world_size = world_size;
comm._is_open = true;
return maybe(mscclppCommInitRank(&comm._handle, world_size, address.c_str(), rank), comm,
"Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size);
mscclppComm_t handle;
checkResult(
mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
"Failed to initialize comms: %s rank=%d world_size=%d",
address,
rank,
world_size);
return new MscclppComm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(), "address"_a, "rank"_a, "world_size"_a,
"Initialize comms given an IP address, rank, and world_size")
.def_static(
"init_rank_from_id",
[](const mscclppUniqueId& id, int rank, int world_size) {
MscclppComm comm = {0};
comm._rank = rank;
comm._world_size = world_size;
comm._is_open = true;
return maybe(mscclppCommInitRankFromId(&comm._handle, world_size, id, rank), comm,
"Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size);
mscclppComm_t handle;
checkResult(
mscclppCommInitRankFromId(&handle, world_size, id, rank),
"Failed to initialize comms: %02X%s rank=%d world_size=%d",
id.internal,
rank,
world_size);
return new MscclppComm(rank, world_size, handle);
},
nb::rv_policy::take_ownership,
nb::call_guard<nb::gil_scoped_release>(), "id"_a, "rank"_a, "world_size"_a,
"Initialize comms given u UniqueID, rank, and world_size")
.def(
@@ -202,8 +213,7 @@ NB_MODULE(_py_mscclpp, m)
[](MscclppComm& comm, int val) -> std::vector<int> {
std::vector<int> buf(comm._world_size);
buf[comm._rank] = val;
// this call segfaults; disabling this call does not.
checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)), "All Gather Failed");
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
return buf;
},
nb::call_guard<nb::gil_scoped_release>())

View File

@@ -69,9 +69,7 @@ class CommsTest(unittest.TestCase):
try:
f.result()
except subprocess.CalledProcessError as e:
errors.append(f"{rank=}: " + e.output.decode('utf-8'))
errors.append(e.output)
if errors:
raise AssertionError("\n\n".join(errors))
raise AssertionError("\n\n".join(e.decode('utf-8', errors='ignore') for e in errors))

View File

@@ -30,9 +30,6 @@ def main():
]),
)
buf = bytearray(world_size)
buf[rank] = rank
comm.close()
if __name__ == '__main__':