mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
allocation fixes
This commit is contained in:
@@ -78,9 +78,12 @@ struct MscclppComm
|
||||
int _rank;
|
||||
int _world_size;
|
||||
mscclppComm_t _handle;
|
||||
bool _is_open = false;
|
||||
bool _is_open;
|
||||
|
||||
public:
|
||||
MscclppComm(int rank, int world_size, mscclppComm_t handle)
|
||||
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true) {}
|
||||
|
||||
~MscclppComm()
|
||||
{
|
||||
close();
|
||||
@@ -91,8 +94,10 @@ public:
|
||||
{
|
||||
if (_is_open) {
|
||||
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
|
||||
_handle = 0;
|
||||
_handle = NULL;
|
||||
_is_open = false;
|
||||
_rank = -1;
|
||||
_world_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,25 +146,31 @@ NB_MODULE(_py_mscclpp, m)
|
||||
.def_static(
|
||||
"init_rank_from_address",
|
||||
[](const std::string& address, int rank, int world_size) {
|
||||
MscclppComm comm = {0};
|
||||
comm._rank = rank;
|
||||
comm._world_size = world_size;
|
||||
comm._is_open = true;
|
||||
return maybe(mscclppCommInitRank(&comm._handle, world_size, address.c_str(), rank), comm,
|
||||
"Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size);
|
||||
mscclppComm_t handle;
|
||||
checkResult(
|
||||
mscclppCommInitRank(&handle, world_size, address.c_str(), rank),
|
||||
"Failed to initialize comms: %s rank=%d world_size=%d",
|
||||
address,
|
||||
rank,
|
||||
world_size);
|
||||
return new MscclppComm(rank, world_size, handle);
|
||||
},
|
||||
nb::rv_policy::take_ownership,
|
||||
nb::call_guard<nb::gil_scoped_release>(), "address"_a, "rank"_a, "world_size"_a,
|
||||
"Initialize comms given an IP address, rank, and world_size")
|
||||
.def_static(
|
||||
"init_rank_from_id",
|
||||
[](const mscclppUniqueId& id, int rank, int world_size) {
|
||||
MscclppComm comm = {0};
|
||||
comm._rank = rank;
|
||||
comm._world_size = world_size;
|
||||
comm._is_open = true;
|
||||
return maybe(mscclppCommInitRankFromId(&comm._handle, world_size, id, rank), comm,
|
||||
"Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size);
|
||||
mscclppComm_t handle;
|
||||
checkResult(
|
||||
mscclppCommInitRankFromId(&handle, world_size, id, rank),
|
||||
"Failed to initialize comms: %02X%s rank=%d world_size=%d",
|
||||
id.internal,
|
||||
rank,
|
||||
world_size);
|
||||
return new MscclppComm(rank, world_size, handle);
|
||||
},
|
||||
nb::rv_policy::take_ownership,
|
||||
nb::call_guard<nb::gil_scoped_release>(), "id"_a, "rank"_a, "world_size"_a,
|
||||
"Initialize comms given u UniqueID, rank, and world_size")
|
||||
.def(
|
||||
@@ -202,8 +213,7 @@ NB_MODULE(_py_mscclpp, m)
|
||||
[](MscclppComm& comm, int val) -> std::vector<int> {
|
||||
std::vector<int> buf(comm._world_size);
|
||||
buf[comm._rank] = val;
|
||||
// this call segfaults; disabling this call does not.
|
||||
checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)), "All Gather Failed");
|
||||
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
|
||||
return buf;
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
|
||||
@@ -69,9 +69,7 @@ class CommsTest(unittest.TestCase):
|
||||
try:
|
||||
f.result()
|
||||
except subprocess.CalledProcessError as e:
|
||||
errors.append(f"{rank=}: " + e.output.decode('utf-8'))
|
||||
errors.append(e.output)
|
||||
|
||||
if errors:
|
||||
raise AssertionError("\n\n".join(errors))
|
||||
|
||||
|
||||
raise AssertionError("\n\n".join(e.decode('utf-8', errors='ignore') for e in errors))
|
||||
|
||||
@@ -30,9 +30,6 @@ def main():
|
||||
]),
|
||||
)
|
||||
|
||||
buf = bytearray(world_size)
|
||||
buf[rank] = rank
|
||||
|
||||
comm.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user