docs and format

This commit is contained in:
Crutcher Dunnavant
2023-04-04 18:55:08 +00:00
parent 659a88a767
commit 151b29f70c
5 changed files with 155 additions and 62 deletions

View File

@@ -1,5 +1,10 @@
#!/bin/bash
set -ex
isort src
black src
clang-format -style='{
"BasedOnStyle": "google",
"BinPackParameters": false,

View File

@@ -1,3 +1,7 @@
nanobind
black
isort
pytest
PyHamcrest
nanobind

View File

@@ -1,10 +1,10 @@
import os
import re
import atexit
from typing import Any
import json
import pickle
import logging
import os
import pickle
import re
from typing import Any, Optional, final
logger = logging.getLogger(__file__)
@@ -20,25 +20,49 @@ _Comm = _py_mscclpp._Comm
# Public aliases for names provided by the `_py_mscclpp` module, so callers
# can use `mscclpp.MscclppUniqueId` / `mscclpp.MSCCLPP_UNIQUE_ID_BYTES`.
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
# Byte length of a serialized unique id.
# NOTE(review): presumably the exact size `MscclppUniqueId.from_bytes`
# accepts -- confirm against the binding.
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
def _mscclpp_log_cb(msg: str) -> None:
    """Log callback hook called from inside _py_mscclpp.

    Forwards ``msg`` to this module's ``logger``. When the message carries a
    ``MSCCLPP <LEVEL>`` tag and ``<LEVEL>`` is a level name registered with
    :mod:`logging` (for example ``INFO`` or ``WARN``), the record is emitted
    at that level; in every other case it is emitted at ``logging.INFO``.

    :param msg: the log line to forward.
    """
    # Attempt to parse out the original log level.
    level = logging.INFO
    if match := re.search(r"MSCCLPP (\w+)", msg):
        # getLevelName() maps a registered level *name* to its number and
        # hands back a "Level <name>" string for names it does not know, so
        # only an int result is a usable level.
        parsed = logging.getLevelName(match.group(1))
        if isinstance(parsed, int):
            level = parsed

    # Actually log the event.
    logger.log(level, msg)
def _set_mscclpp_debug_log_level(level: str = 'INFO'):
    """Store ``level`` in the ``MSCCLPP_DEBUG`` environment variable.

    :param level: the value to store; it is written as given, without any
        validation or case normalization.
    """
    os.environ.update({'MSCCLPP_DEBUG': level})
# The known log levels used by MSCCLPP.
# Set in os.environ['MSCCLPP_DEBUG'] and only parsed on first init.
MSCCLPP_LOG_LEVELS: set[str] = {
    "DEBUG",
    "INFO",
    "WARN",
    "ABORT",
    "TRACE",
}


def _setup_logging(level: str = "INFO") -> None:
    """Setup log hooks for the C library.

    Writes the requested level to the ``MSCCLPP_DEBUG`` environment
    variable, binds ``_mscclpp_log_cb`` as the library's log handler, and
    registers the matching release call with :mod:`atexit`.

    :param level: one of ``MSCCLPP_LOG_LEVELS``, matched case-insensitively.
        An unknown value is reported with a warning and replaced by
        ``"INFO"``.
    """
    requested = level.upper()
    if requested not in MSCCLPP_LOG_LEVELS:
        # Keep the permissive fallback, but make the substitution visible
        # instead of silently discarding what the caller asked for.
        logger.warning(
            "Unknown MSCCLPP log level %r; falling back to INFO", level
        )
        requested = "INFO"
    os.environ["MSCCLPP_DEBUG"] = requested

    _py_mscclpp._bind_log_handler(_mscclpp_log_cb)
    # needed to prevent a segfault at exit.
    atexit.register(_py_mscclpp._release_log_handler)
_setup_logging()
@final
class Comm:
"""Comm object; represents a mscclpp connection."""
_comm: _Comm
@staticmethod
@@ -46,9 +70,21 @@ class Comm:
address: str,
rank: int,
world_size: int,
*,
port: Optional[int] = None,
):
"""Initialize a Comm from an address.
:param address: the address as a string, with optional port.
:param rank: this Comm's rank.
:param world_size: the total world size.
:param port: (optional) port, appended to address.
:return: a newly initialized Comm.
"""
if port is not None:
address = f"{address}:{port}"
return Comm(
_comm = _Comm.init_rank_from_address(
_comm=_Comm.init_rank_from_address(
address=address,
rank=rank,
world_size=world_size,
@@ -56,35 +92,60 @@ class Comm:
)
def __init__(self, *, _comm: _Comm) -> None:
    """Construct a Comm object wrapping an internal _Comm handle.

    :param _comm: keyword-only; the handle that every method and property
        of this object delegates to.
    """
    self._comm = _comm
def __del__(self) -> None:
    """Close the connection when the object is finalized.

    Does nothing when there is no live handle left to release.
    """
    # `close()` resets `_comm` to None, and `_comm` does not exist at all
    # when `__init__` never ran; calling `close()` in either state would
    # raise AttributeError during finalization.
    if getattr(self, "_comm", None) is not None:
        self.close()
def close(self) -> None:
    """Close the connection.

    Safe to call more than once: the wrapped handle is closed on the first
    call and dropped, and later calls return without doing anything.
    """
    comm = getattr(self, "_comm", None)
    if comm is None:
        # Already closed (or never initialized); nothing to release.
        return
    comm.close()
    self._comm = None
@property
def rank(self) -> int:
    """Return the rank of the Comm.

    Assumes the Comm is open: ``close()`` replaces the wrapped handle with
    ``None``, so reading this property afterwards raises ``AttributeError``.

    :return: the ``rank`` attribute of the wrapped ``_Comm`` handle.
    """
    return self._comm.rank
@property
def world_size(self) -> int:
    """Return the world_size of the Comm.

    Assumes the Comm is open: ``close()`` replaces the wrapped handle with
    ``None``, so reading this property afterwards raises ``AttributeError``.

    :return: the ``world_size`` attribute of the wrapped ``_Comm`` handle.
    """
    return self._comm.world_size
def bootstrap_all_gather_int(self, val: int) -> list[int]:
    """AllGather an int value through the bootstrap interface.

    Delegates directly to the wrapped ``_Comm`` handle.

    :param val: the int contributed by this rank.
    :return: the gathered ints.
        NOTE(review): presumably one entry per rank, in rank order --
        confirm against the binding.
    """
    return self._comm.bootstrap_all_gather_int(val)
def all_gather_bytes(self, item: bytes) -> list[bytes]:
    """AllGather bytes (of different sizes) through the bootstrap interface.

    Delegates directly to the wrapped ``_Comm`` handle.

    :param item: the bytes object for this rank.
    :return: a list of bytes objects; the ret[rank] object will be a new
        copy.
    """
    return self._comm.all_gather_bytes(item)
def all_gather_json(self, item: Any) -> list[Any]:
    """AllGather JSON objects through the bootstrap interface.

    ``item`` is serialized with :func:`json.dumps`, exchanged as UTF-8
    bytes via :meth:`all_gather_bytes`, and every received payload is
    parsed back with :func:`json.loads`.

    :param item: the JSON object for this rank.
    :return: a list of JSON objects; the ret[rank] object will be a new
        copy.
    :raises TypeError: if ``item`` cannot be serialized by ``json.dumps``.
    """
    # Encode once up front; every entry that comes back is decoded the
    # same way, including this rank's own contribution.
    payload = json.dumps(item).encode("utf-8")
    return [
        json.loads(raw.decode("utf-8"))
        for raw in self.all_gather_bytes(payload)
    ]
def all_gather_pickle(self, item: Any) -> list[Any]:
    """AllGather pickle-able objects through the bootstrap interface.

    ``item`` is serialized with :func:`pickle.dumps`, exchanged via
    :meth:`all_gather_bytes`, and every received payload is restored with
    :func:`pickle.loads`.

    :param item: the object for this rank.
    :return: a list of de-pickled objects. Note, the ret[rank] item will be
        a new copy.
    """
    payload = pickle.dumps(item)
    # NOTE(security): pickle.loads can execute arbitrary code embedded in a
    # payload, so this must only be used when every rank is trusted.
    return [pickle.loads(raw) for raw in self.all_gather_bytes(payload)]

View File

@@ -1,15 +1,17 @@
import os
import sys
import concurrent.futures
import unittest
import hamcrest
import os
import subprocess
import sys
import unittest
import hamcrest
import mscclpp
# Directory that contains this module.
MOD_DIR = os.path.dirname(__file__)
# The "tests" directory located next to this module.
TESTS_DIR = os.path.join(MOD_DIR, "tests")
class UniqueIdTest(unittest.TestCase):
def test_no_constructor(self) -> None:
hamcrest.assert_that(
@@ -37,13 +39,14 @@ class UniqueIdTest(unittest.TestCase):
# bad size
hamcrest.assert_that(
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b'abc'),
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b"abc"),
hamcrest.raises(
ValueError,
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3"
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3",
),
)
class CommsTest(unittest.TestCase):
def test_all_gather(self) -> None:
world_size = 2
@@ -52,17 +55,19 @@ class CommsTest(unittest.TestCase):
with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as pool:
for rank in range(world_size):
tasks.append(pool.submit(
subprocess.check_output,
[
"python",
"-m",
"mscclpp.tests.bootstrap_test",
f"--rank={rank}",
f"--world_size={world_size}",
],
stderr=subprocess.STDOUT,
))
tasks.append(
pool.submit(
subprocess.check_output,
[
"python",
"-m",
"mscclpp.tests.bootstrap_test",
f"--rank={rank}",
f"--world_size={world_size}",
],
stderr=subprocess.STDOUT,
)
)
errors = []
for rank, f in enumerate(tasks):
@@ -72,4 +77,6 @@ class CommsTest(unittest.TestCase):
errors.append(e.output)
if errors:
raise AssertionError("\n\n".join(e.decode('utf-8', errors='ignore') for e in errors))
raise AssertionError(
"\n\n".join(e.decode("utf-8", errors="ignore") for e in errors)
)

View File

@@ -1,11 +1,16 @@
from dataclasses import dataclass
import argparse
from dataclasses import dataclass
import hamcrest
import mscclpp
@dataclass
class Example:
    """Small record sent through ``Comm.all_gather_pickle`` by this test."""

    # Rank value supplied when the instance is created.
    rank: int
def main():
p = argparse.ArgumentParser()
p.add_argument("--rank", type=int, required=True)
@@ -18,7 +23,7 @@ def main():
rank=options.rank,
world_size=options.world_size,
)
print(f'{comm_options=}', flush=True)
print(f"{comm_options=}", flush=True)
comm = mscclpp.Comm.init_rank_from_address(**comm_options)
# comm.connection_setup()
@@ -28,45 +33,56 @@ def main():
hamcrest.assert_that(
comm.bootstrap_all_gather_int(options.rank + 42),
hamcrest.equal_to([
42,
43,
]),
hamcrest.equal_to(
[
42,
43,
]
),
)
hamcrest.assert_that(
comm.all_gather_bytes(b'abc' * (1 + options.rank)),
hamcrest.equal_to([
b'abc',
b'abcabc',
]),
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
hamcrest.equal_to(
[
b"abc",
b"abcabc",
]
),
)
hamcrest.assert_that(
comm.all_gather_json({'rank': options.rank}),
hamcrest.equal_to([
{'rank': 0},
{'rank': 1},
]),
comm.all_gather_json({"rank": options.rank}),
hamcrest.equal_to(
[
{"rank": 0},
{"rank": 1},
]
),
)
hamcrest.assert_that(
comm.all_gather_json([options.rank, 42]),
hamcrest.equal_to([
[0, 42],
[1, 42],
]),
hamcrest.equal_to(
[
[0, 42],
[1, 42],
]
),
)
hamcrest.assert_that(
comm.all_gather_pickle(Example(rank=options.rank)),
hamcrest.equal_to([
Example(rank=0),
Example(rank=1),
]),
hamcrest.equal_to(
[
Example(rank=0),
Example(rank=1),
]
),
)
comm.close()
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()