mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Fix a pytest bug (#196)
This commit is contained in:
@@ -123,7 +123,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
-@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
|
||||
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
|
||||
@@ -139,6 +139,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
memory_expected = memory.copy()
|
||||
else:
|
||||
memory = xp.zeros(nelem, dtype=xp.float32)
|
||||
+if device == "cuda":
|
||||
+cp.cuda.runtime.deviceSynchronize()
|
||||
|
||||
signal_memory = xp.zeros(1, dtype=xp.int64)
|
||||
all_reg_memories = group.register_tensor_with_connections(memory, connections)
|
||||
@@ -156,6 +158,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
connections[next_rank].flush()
|
||||
if group.my_rank == 0:
|
||||
memory[:] = 0
|
||||
+if device == "cuda":
|
||||
+cp.cuda.runtime.deviceSynchronize()
|
||||
connections[next_rank].update_and_sync(
|
||||
all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user