update pytest and python API to fix ut failure (#598)

update pytest and python API to fix ut failure
This commit is contained in:
Binyang Li
2025-08-05 15:17:33 -07:00
committed by GitHub
parent 334b232e36
commit 658411ccc4
3 changed files with 30 additions and 29 deletions

View File

@@ -141,18 +141,18 @@ def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
mpi_group.comm.barrier()
def create_connection(group: mscclpp_comm.CommGroup, transport: str):
if transport == "NVLS":
def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
if connection_type == "NVLS":
all_ranks = list(range(group.nranks))
tran = Transport.Nvls
connection = group.make_connection(all_ranks, tran)
tran = Transport.CudaIpc
connection = group.make_connection(all_ranks, tran, use_switch=True)
return connection
remote_nghrs = list(range(group.nranks))
remote_nghrs.remove(group.my_rank)
if transport == "NVLink":
if connection_type == "NVLink":
tran = Transport.CudaIpc
elif transport == "IB":
elif connection_type == "IB":
tran = group.my_ib_device(group.my_rank % 8)
else:
assert False
@@ -160,14 +160,14 @@ def create_connection(group: mscclpp_comm.CommGroup, transport: str):
return connections
def create_group_and_connection(mpi_group: MpiGroup, transport: str):
if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink/nvls for cross node")
group = mscclpp_comm.CommGroup(mpi_group.comm)
try:
connection = create_connection(group, transport)
connection = create_connection(group, connection_type)
except Error as e:
if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
pytest.skip("IB not supported on this node")
raise
return group, connection
@@ -194,10 +194,10 @@ def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
group, connections = create_group_and_connection(mpi_group, transport)
def test_connection_write(mpi_group: MpiGroup, connection_type: str, nelem: int):
group, connections = create_group_and_connection(mpi_group, connection_type)
memory = GpuBuffer(nelem, dtype=cp.int32)
nelemPerRank = nelem // group.nranks
sizePerRank = nelemPerRank * memory.itemsize
@@ -229,16 +229,16 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str, nelem: int, device: str):
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
# and finally, comes back to rank 0 to make sure it matches all the original values
if device == "cpu" and transport == "NVLink":
if device == "cpu" and connection_type == "NVLink":
pytest.skip("nvlink doesn't work with host allocated memory")
group, connections = create_group_and_connection(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, connection_type)
xp = cp if device == "cuda" else np
if group.my_rank == 0:
memory = xp.random.randn(nelem)
@@ -339,7 +339,7 @@ def test_nvls_connection(mpi_group: MpiGroup):
pytest.skip("cannot use nvls for cross node")
group = mscclpp_comm.CommGroup(mpi_group.comm)
all_ranks = list(range(group.nranks))
nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
memory1 = GpuBuffer(2**29, cp.int8)
memory2 = GpuBuffer(2**29, cp.int8)
memory3 = GpuBuffer(2**29, cp.int8)
@@ -449,13 +449,13 @@ class MscclppKernel:
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
def signal(semaphores):
for rank in semaphores:
semaphores[rank].signal()
group, connections = create_group_and_connection(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, connection_type)
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
@@ -530,9 +530,9 @@ def test_fifo(
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
group, connections = create_group_and_connection(mpi_group, transport)
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
group, connections = create_group_and_connection(mpi_group, connection_type)
memory = GpuBuffer(nelem, dtype=cp.int32)
nelemPerRank = nelem // group.nranks
@@ -579,10 +579,10 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
@pytest.mark.parametrize("use_packet", [False, True])
def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, transport)
def test_port_channel(mpi_group: MpiGroup, nelem: int, connection_type: str, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, connection_type)
memory = GpuBuffer(nelem, dtype=cp.int32)
if use_packet: