mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
cleanup tests
This commit is contained in:
@@ -6,7 +6,6 @@ import hamcrest
|
||||
import torch
|
||||
|
||||
import mscclpp
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -14,7 +13,7 @@ class Example:
|
||||
rank: int
|
||||
|
||||
|
||||
def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_bootstrap_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.bootstrap_all_gather_int(options.rank + 42),
|
||||
hamcrest.equal_to(
|
||||
@@ -26,7 +25,7 @@ def _test_allgather_int(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_bootstrap_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_bytes(b"abc" * (1 + options.rank)),
|
||||
hamcrest.equal_to(
|
||||
@@ -38,7 +37,7 @@ def _test_allgather_bytes(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_bootstrap_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_json({"rank": options.rank}),
|
||||
hamcrest.equal_to(
|
||||
@@ -60,7 +59,7 @@ def _test_allgather_json(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
)
|
||||
|
||||
|
||||
def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
hamcrest.assert_that(
|
||||
comm.all_gather_pickle(Example(rank=options.rank)),
|
||||
hamcrest.equal_to(
|
||||
@@ -74,7 +73,7 @@ def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
comm.connection_setup()
|
||||
|
||||
|
||||
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
rank = options.rank
|
||||
|
||||
buf = torch.zeros([options.world_size], dtype=torch.int64)
|
||||
@@ -97,7 +96,6 @@ def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# time.sleep(3)
|
||||
|
||||
comm.connection_setup()
|
||||
|
||||
@@ -125,11 +123,11 @@ def main():
|
||||
hamcrest.assert_that(comm.world_size, hamcrest.equal_to(options.world_size))
|
||||
|
||||
try:
|
||||
_test_allgather_int(options, comm)
|
||||
_test_allgather_bytes(options, comm)
|
||||
_test_allgather_json(options, comm)
|
||||
_test_allgather_pickle(options, comm)
|
||||
_test_allgather_torch(options, comm)
|
||||
_test_bootstrap_allgather_int(options, comm)
|
||||
_test_bootstrap_allgather_bytes(options, comm)
|
||||
_test_bootstrap_allgather_json(options, comm)
|
||||
_test_bootstrap_allgather_pickle(options, comm)
|
||||
_test_p2p_connect(options, comm)
|
||||
finally:
|
||||
comm.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user