Fix a pytest bug (#196)

This commit is contained in:
Saeed Maleki
2023-10-13 01:39:43 -07:00
committed by GitHub
parent 8c0f9e84d0
commit 148681b4bc
2 changed files with 7 additions and 3 deletions

View File

@@ -123,7 +123,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
@@ -139,6 +139,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
memory_expected = memory.copy()
else:
memory = xp.zeros(nelem, dtype=xp.float32)
if device == "cuda":
cp.cuda.runtime.deviceSynchronize()
signal_memory = xp.zeros(1, dtype=xp.int64)
all_reg_memories = group.register_tensor_with_connections(memory, connections)
@@ -156,6 +158,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
connections[next_rank].flush()
if group.my_rank == 0:
memory[:] = 0
if device == "cuda":
cp.cuda.runtime.deviceSynchronize()
connections[next_rank].update_and_sync(
all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val
)