cleanup tests

This commit is contained in:
Crutcher Dunnavant
2023-04-07 11:37:24 -07:00
parent 68eff98bbc
commit d014693288

View File

@@ -6,7 +6,6 @@ import hamcrest
import torch
import mscclpp
import time
@dataclass
@@ -14,7 +13,7 @@ class Example:
rank: int
def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_bootstrap_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.bootstrap_all_gather_int(options.rank + 42),
hamcrest.equal_to(
@@ -26,7 +25,7 @@ def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
)
def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_bootstrap_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
hamcrest.equal_to(
@@ -38,7 +37,7 @@ def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
)
def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_bootstrap_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_json({"rank": options.rank}),
hamcrest.equal_to(
@@ -60,7 +59,7 @@ def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
)
def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
hamcrest.assert_that(
comm.all_gather_pickle(Example(rank=options.rank)),
hamcrest.equal_to(
@@ -74,7 +73,7 @@ def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
comm.connection_setup()
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
rank = options.rank
buf = torch.zeros([options.world_size], dtype=torch.int64)
@@ -97,7 +96,6 @@ def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
)
torch.cuda.synchronize()
# time.sleep(3)
comm.connection_setup()
@@ -125,11 +123,11 @@ def main():
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
try:
_test_allgather_int(options, comm)
_test_allgather_bytes(options, comm)
_test_allgather_json(options, comm)
_test_allgather_pickle(options, comm)
_test_allgather_torch(options, comm)
_test_bootstrap_allgather_int(options, comm)
_test_bootstrap_allgather_bytes(options, comm)
_test_bootstrap_allgather_json(options, comm)
_test_bootstrap_allgather_pickle(options, comm)
_test_p2p_connect(options, comm)
finally:
comm.close()