mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
[python] working on bootstrap all gather bug
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
#include <mscclpp.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
@@ -52,12 +54,13 @@ template <typename... Args> void checkResult(mscclppResult_t status, const std::
|
||||
case mscclppRemoteError:
|
||||
case mscclppInProgress:
|
||||
case mscclppNumResults:
|
||||
// throw std::runtime_error(string_format(format, args...) + " : " + std::string(mscclppGetErrorString(status)));
|
||||
throw std::runtime_error(string_format(format, args...));
|
||||
|
||||
case mscclppInvalidArgument:
|
||||
case mscclppInvalidUsage:
|
||||
default:
|
||||
throw std::invalid_argument(string_format(format, args...));
|
||||
throw std::invalid_argument(string_format(format, args...));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +75,8 @@ Val maybe(mscclppResult_t status, Val val, const std::string& format, Args... ar
|
||||
// Wrapper around connection state.
|
||||
struct MscclppComm
|
||||
{
|
||||
int _rank;
|
||||
int _world_size;
|
||||
mscclppComm_t _handle;
|
||||
bool _is_open = false;
|
||||
|
||||
@@ -137,6 +142,8 @@ NB_MODULE(_py_mscclpp, m)
|
||||
"init_rank_from_address",
|
||||
[](const std::string& address, int rank, int world_size) {
|
||||
MscclppComm comm = {0};
|
||||
comm._rank = rank;
|
||||
comm._world_size = world_size;
|
||||
comm._is_open = true;
|
||||
return maybe(mscclppCommInitRank(&comm._handle, world_size, address.c_str(), rank), comm,
|
||||
"Failed to initialize comms: %s rank=%d world_size=%d", address, rank, world_size);
|
||||
@@ -147,6 +154,8 @@ NB_MODULE(_py_mscclpp, m)
|
||||
"init_rank_from_id",
|
||||
[](const mscclppUniqueId& id, int rank, int world_size) {
|
||||
MscclppComm comm = {0};
|
||||
comm._rank = rank;
|
||||
comm._world_size = world_size;
|
||||
comm._is_open = true;
|
||||
return maybe(mscclppCommInitRankFromId(&comm._handle, world_size, id, rank), comm,
|
||||
"Failed to initialize comms: %02X%s rank=%d world_size=%d", id.internal, rank, world_size);
|
||||
@@ -157,22 +166,8 @@ NB_MODULE(_py_mscclpp, m)
|
||||
"opened", [](MscclppComm& comm) { return comm._is_open; }, "Is this comm object opened?")
|
||||
.def(
|
||||
"closed", [](MscclppComm& comm) { return !comm._is_open; }, "Is this comm object closed?")
|
||||
.def(
|
||||
"rank",
|
||||
[](MscclppComm& comm) {
|
||||
comm.check_open();
|
||||
int rank;
|
||||
return maybe(mscclppCommRank(comm._handle, &rank), rank, "Failed to retrieve MSCCLPP rank");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(), "The rank of this node.")
|
||||
.def(
|
||||
"size",
|
||||
[](MscclppComm& comm) {
|
||||
comm.check_open();
|
||||
int size;
|
||||
return maybe(mscclppCommSize(comm._handle, &size), size, "Failed to retrieve MSCCLPP world size");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(), "The world size of this node.")
|
||||
.def_ro( "rank", &MscclppComm::_rank)
|
||||
.def_ro( "world_size", &MscclppComm::_world_size)
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](MscclppComm& comm) {
|
||||
@@ -196,6 +191,22 @@ NB_MODULE(_py_mscclpp, m)
|
||||
nb::call_guard<nb::gil_scoped_release>(), "Start the MSCCLPP proxy.")
|
||||
.def("close", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("__del__", &MscclppComm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("connection_setup",
|
||||
[](MscclppComm& comm) -> void {
|
||||
comm.check_open();
|
||||
checkResult(mscclppConnectionSetup(comm._handle), "Connection Setup Failed");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"bootstrap_all_gather_int",
|
||||
[](MscclppComm& comm, int val) -> std::vector<int> {
|
||||
std::vector<int> buf(comm._world_size);
|
||||
buf[comm._rank] = val;
|
||||
// this call segfaults; disabling this call does not.
|
||||
checkResult(mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int)), "All Gather Failed");
|
||||
return buf;
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"bootstrap_all_gather",
|
||||
[](MscclppComm& comm, void* data, int size) {
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import unittest
|
||||
import hamcrest
|
||||
import subprocess
|
||||
|
||||
import mscclpp
|
||||
|
||||
MOD_DIR = os.path.dirname(__file__)
|
||||
TESTS_DIR = os.path.join(MOD_DIR, "tests")
|
||||
|
||||
class UniqueIdTest(unittest.TestCase):
|
||||
def test_no_constructor(self) -> None:
|
||||
@@ -39,41 +44,34 @@ class UniqueIdTest(unittest.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def all_gather_task(rank: int, world_size: int) -> None:
|
||||
comm_options = dict(
|
||||
address="127.0.0.1:50000",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
print(f'{comm_options=}', flush=True)
|
||||
|
||||
comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options)
|
||||
|
||||
buf = bytearray(world_size)
|
||||
buf[rank] = rank
|
||||
|
||||
if False:
|
||||
# crashes, bad call structure..
|
||||
comm.bootstrap_all_gather(memoryview(buf), world_size)
|
||||
hamcrest.assert_that(
|
||||
buf,
|
||||
hamcrest.equal_to(b'\000\002'),
|
||||
)
|
||||
|
||||
comm.close()
|
||||
|
||||
|
||||
class CommsTest(unittest.TestCase):
|
||||
def test_all_gather(self) -> None:
|
||||
world_size = 2
|
||||
|
||||
tasks: list[concurrent.futures.Future[None]] = []
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=world_size) as pool:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool:
|
||||
for rank in range(world_size):
|
||||
tasks.append(pool.submit(all_gather_task, rank, world_size))
|
||||
tasks.append(pool.submit(
|
||||
subprocess.check_output,
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"mscclpp.tests.bootstrap_test",
|
||||
f"--rank={rank}",
|
||||
f"--world_size={world_size}",
|
||||
],
|
||||
stderr=subprocess.STDOUT,
|
||||
))
|
||||
|
||||
for f in concurrent.futures.as_completed(tasks):
|
||||
f.result()
|
||||
errors = []
|
||||
for rank, f in enumerate(tasks):
|
||||
try:
|
||||
f.result()
|
||||
except subprocess.CalledProcessError as e:
|
||||
errors.append(f"{rank=}: " + e.output.decode('utf-8'))
|
||||
|
||||
if errors:
|
||||
raise AssertionError("\n\n".join(errors))
|
||||
|
||||
|
||||
|
||||
0
python/src/mscclpp/tests/__init__.py
Normal file
0
python/src/mscclpp/tests/__init__.py
Normal file
39
python/src/mscclpp/tests/bootstrap_test.py
Normal file
39
python/src/mscclpp/tests/bootstrap_test.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import argparse
|
||||
import hamcrest
|
||||
import mscclpp
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rank", type=int, required=True)
|
||||
p.add_argument("--world_size", type=int, required=True)
|
||||
p.add_argument("--port", default=50000)
|
||||
options = p.parse_args()
|
||||
|
||||
comm_options = dict(
|
||||
address=f"127.0.0.1:{options.port}",
|
||||
rank=options.rank,
|
||||
world_size=options.world_size,
|
||||
)
|
||||
print(f'{comm_options=}', flush=True)
|
||||
|
||||
comm = mscclpp.MscclppComm.init_rank_from_address(**comm_options)
|
||||
# comm.connection_setup()
|
||||
|
||||
hamcrest.assert_that(comm.rank, hamcrest.equal_to(options.rank))
|
||||
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.bootstrap_all_gather_int(options.rank + 42),
|
||||
hamcrest.equal_to([
|
||||
42,
|
||||
43,
|
||||
]),
|
||||
)
|
||||
|
||||
buf = bytearray(world_size)
|
||||
buf[rank] = rank
|
||||
|
||||
comm.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user