mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 14:29:13 +00:00
docs and format
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
isort src
|
||||
black src
|
||||
|
||||
clang-format -style='{
|
||||
"BasedOnStyle": "google",
|
||||
"BinPackParameters": false,
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
nanobind
|
||||
black
|
||||
isort
|
||||
|
||||
pytest
|
||||
PyHamcrest
|
||||
|
||||
nanobind
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import re
|
||||
import atexit
|
||||
from typing import Any
|
||||
import json
|
||||
import pickle
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
from typing import Any, Optional, final
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
@@ -20,25 +20,49 @@ _Comm = _py_mscclpp._Comm
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
|
||||
|
||||
|
||||
def _mscclpp_log_cb(msg: str) -> None:
|
||||
logging.info
|
||||
"""Log callback hook called from inside _py_mscclpp."""
|
||||
|
||||
# Attempt to parse out the original log level:
|
||||
level = logging.INFO
|
||||
if match := re.search(r'MSCCLPP (\w+)', msg):
|
||||
if match := re.search(r"MSCCLPP (\w+)", msg):
|
||||
level = logging._nameToLevel.get(match.group(1), logging.INFO)
|
||||
|
||||
# actually log the event.
|
||||
logger.log(level, msg)
|
||||
|
||||
def _set_mscclpp_debug_log_level(level: str ='INFO'):
|
||||
os.environ['MSCCLPP_DEBUG'] = level
|
||||
|
||||
def _setup_logging(level='INFO'):
|
||||
_set_mscclpp_debug_log_level(level)
|
||||
# The known log levels used by MSCCLPP.
|
||||
# Set in os.environ['MSCCLPP_DEBUG'] and only parsed on first init.
|
||||
MSCCLPP_LOG_LEVELS: set[str] = {
|
||||
"DEBUG",
|
||||
"INFO",
|
||||
"WARN",
|
||||
"ABORT",
|
||||
"TRACE",
|
||||
}
|
||||
|
||||
|
||||
def _setup_logging(level: str = "INFO"):
|
||||
"""Setup log hooks for the C library."""
|
||||
level = level.upper()
|
||||
if level not in MSCCLPP_LOG_LEVELS:
|
||||
level = "INFO"
|
||||
os.environ["MSCCLPP_DEBUG"] = level
|
||||
|
||||
_py_mscclpp._bind_log_handler(_mscclpp_log_cb)
|
||||
# needed to prevent a segfault at exit.
|
||||
atexit.register(_py_mscclpp._release_log_handler)
|
||||
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
@final
|
||||
class Comm:
|
||||
"""Comm object; represents a mscclpp connection."""
|
||||
|
||||
_comm: _Comm
|
||||
|
||||
@staticmethod
|
||||
@@ -46,9 +70,21 @@ class Comm:
|
||||
address: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
"""Initialize a Comm from an address.
|
||||
|
||||
:param address: the address as a string, with optional port.
|
||||
:param rank: this Comm's rank.
|
||||
:param world_size: the total world size.
|
||||
:param port: (optional) port, appended to address.
|
||||
:return: a newly initialized Comm.
|
||||
"""
|
||||
if port is not None:
|
||||
address = f"{address}:{port}"
|
||||
return Comm(
|
||||
_comm = _Comm.init_rank_from_address(
|
||||
_comm=_Comm.init_rank_from_address(
|
||||
address=address,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
@@ -56,35 +92,60 @@ class Comm:
|
||||
)
|
||||
|
||||
def __init__(self, *, _comm: _Comm):
|
||||
"""Construct a Comm object wrapping an internal _Comm handle."""
|
||||
self._comm = _comm
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
self._comm.close()
|
||||
self._comm = None
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
"""Return the rank of the Comm.
|
||||
|
||||
Assumes the Comm is open.
|
||||
"""
|
||||
return self._comm.rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
"""Return the world_size of the Comm.
|
||||
|
||||
Assumes the Comm is open.
|
||||
"""
|
||||
return self._comm.world_size
|
||||
|
||||
def bootstrap_all_gather_int(self, val: int) -> list[int]:
|
||||
"""AllGather an int value through the bootstrap interface."""
|
||||
return self._comm.bootstrap_all_gather_int(val)
|
||||
|
||||
def all_gather_bytes(self, item: bytes) -> list[bytes]:
|
||||
"""AllGather bytes (of different sizes) through the bootstrap interface.
|
||||
|
||||
:param item: the bytes object for this rank.
|
||||
:return: a list of bytes objects; the ret[rank] object will be a new copy.
|
||||
"""
|
||||
return self._comm.all_gather_bytes(item)
|
||||
|
||||
def all_gather_json(self, item: Any) -> list[Any]:
|
||||
"""AllGather JSON objects through the bootstrap interface.
|
||||
|
||||
:param item: the JSON object for this rank.
|
||||
:return: a list of JSON objects; the ret[rank] object will be a new copy.
|
||||
"""
|
||||
return [
|
||||
json.loads(b.decode('utf-8'))
|
||||
for b in self.all_gather_bytes(bytes(json.dumps(item), 'utf-8'))
|
||||
json.loads(b.decode("utf-8"))
|
||||
for b in self.all_gather_bytes(bytes(json.dumps(item), "utf-8"))
|
||||
]
|
||||
|
||||
def all_gather_pickle(self, item: Any) -> list[Any]:
|
||||
return [
|
||||
pickle.loads(b)
|
||||
for b in self.all_gather_bytes(pickle.dumps(item))
|
||||
]
|
||||
"""AllGather pickle-able objects through the bootstrap interface.
|
||||
|
||||
:param item: the object for this rank.
|
||||
:return: a list of de-pickled objects. Note, the ret[rank] item will be a new copy.
|
||||
"""
|
||||
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import unittest
|
||||
import hamcrest
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import hamcrest
|
||||
|
||||
import mscclpp
|
||||
|
||||
MOD_DIR = os.path.dirname(__file__)
|
||||
TESTS_DIR = os.path.join(MOD_DIR, "tests")
|
||||
|
||||
|
||||
class UniqueIdTest(unittest.TestCase):
|
||||
def test_no_constructor(self) -> None:
|
||||
hamcrest.assert_that(
|
||||
@@ -37,13 +39,14 @@ class UniqueIdTest(unittest.TestCase):
|
||||
|
||||
# bad size
|
||||
hamcrest.assert_that(
|
||||
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b'abc'),
|
||||
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b"abc"),
|
||||
hamcrest.raises(
|
||||
ValueError,
|
||||
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3"
|
||||
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CommsTest(unittest.TestCase):
|
||||
def test_all_gather(self) -> None:
|
||||
world_size = 2
|
||||
@@ -52,17 +55,19 @@ class CommsTest(unittest.TestCase):
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool:
|
||||
for rank in range(world_size):
|
||||
tasks.append(pool.submit(
|
||||
subprocess.check_output,
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"mscclpp.tests.bootstrap_test",
|
||||
f"--rank={rank}",
|
||||
f"--world_size={world_size}",
|
||||
],
|
||||
stderr=subprocess.STDOUT,
|
||||
))
|
||||
tasks.append(
|
||||
pool.submit(
|
||||
subprocess.check_output,
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"mscclpp.tests.bootstrap_test",
|
||||
f"--rank={rank}",
|
||||
f"--world_size={world_size}",
|
||||
],
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
)
|
||||
|
||||
errors = []
|
||||
for rank, f in enumerate(tasks):
|
||||
@@ -72,4 +77,6 @@ class CommsTest(unittest.TestCase):
|
||||
errors.append(e.output)
|
||||
|
||||
if errors:
|
||||
raise AssertionError("\n\n".join(e.decode('utf-8', errors='ignore') for e in errors))
|
||||
raise AssertionError(
|
||||
"\n\n".join(e.decode("utf-8", errors="ignore") for e in errors)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import hamcrest
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
@dataclass
|
||||
class Example:
|
||||
rank: int
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rank", type=int, required=True)
|
||||
@@ -18,7 +23,7 @@ def main():
|
||||
rank=options.rank,
|
||||
world_size=options.world_size,
|
||||
)
|
||||
print(f'{comm_options=}', flush=True)
|
||||
print(f"{comm_options=}", flush=True)
|
||||
|
||||
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
|
||||
# comm.connection_setup()
|
||||
@@ -28,45 +33,56 @@ def main():
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.bootstrap_all_gather_int(options.rank + 42),
|
||||
hamcrest.equal_to([
|
||||
42,
|
||||
43,
|
||||
]),
|
||||
hamcrest.equal_to(
|
||||
[
|
||||
42,
|
||||
43,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_bytes(b'abc' * (1 + options.rank)),
|
||||
hamcrest.equal_to([
|
||||
b'abc',
|
||||
b'abcabc',
|
||||
]),
|
||||
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
|
||||
hamcrest.equal_to(
|
||||
[
|
||||
b"abc",
|
||||
b"abcabc",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_json({'rank': options.rank}),
|
||||
hamcrest.equal_to([
|
||||
{'rank': 0},
|
||||
{'rank': 1},
|
||||
]),
|
||||
comm.all_gather_json({"rank": options.rank}),
|
||||
hamcrest.equal_to(
|
||||
[
|
||||
{"rank": 0},
|
||||
{"rank": 1},
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_json([options.rank, 42]),
|
||||
hamcrest.equal_to([
|
||||
[0, 42],
|
||||
[1, 42],
|
||||
]),
|
||||
hamcrest.equal_to(
|
||||
[
|
||||
[0, 42],
|
||||
[1, 42],
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_pickle(Example(rank=options.rank)),
|
||||
hamcrest.equal_to([
|
||||
Example(rank=0),
|
||||
Example(rank=1),
|
||||
]),
|
||||
hamcrest.equal_to(
|
||||
[
|
||||
Example(rank=0),
|
||||
Example(rank=1),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
comm.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user