mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
update pytest and python API to fix ut failure (#598)
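Update pytest and Python API to fix unit-test (UT) failure: in the test hunks below, the `transport` parameter is renamed to `connection_type`, and NVLS connections are created with `Transport.CudaIpc` and `use_switch=True` instead of `Transport.Nvls`.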
This commit is contained in:
@@ -141,18 +141,18 @@ def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
|
||||
mpi_group.comm.barrier()
|
||||
|
||||
|
||||
def create_connection(group: mscclpp_comm.CommGroup, transport: str):
|
||||
if transport == "NVLS":
|
||||
def create_connection(group: mscclpp_comm.CommGroup, connection_type: str):
|
||||
if connection_type == "NVLS":
|
||||
all_ranks = list(range(group.nranks))
|
||||
tran = Transport.Nvls
|
||||
connection = group.make_connection(all_ranks, tran)
|
||||
tran = Transport.CudaIpc
|
||||
connection = group.make_connection(all_ranks, tran, use_switch=True)
|
||||
return connection
|
||||
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
if transport == "NVLink":
|
||||
if connection_type == "NVLink":
|
||||
tran = Transport.CudaIpc
|
||||
elif transport == "IB":
|
||||
elif connection_type == "IB":
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
else:
|
||||
assert False
|
||||
@@ -160,14 +160,14 @@ def create_connection(group: mscclpp_comm.CommGroup, transport: str):
|
||||
return connections
|
||||
|
||||
|
||||
def create_group_and_connection(mpi_group: MpiGroup, transport: str):
|
||||
if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
def create_group_and_connection(mpi_group: MpiGroup, connection_type: str):
|
||||
if (connection_type == "NVLink" or connection_type == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink/nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
try:
|
||||
connection = create_connection(group, transport)
|
||||
connection = create_connection(group, connection_type)
|
||||
except Error as e:
|
||||
if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
|
||||
if connection_type == "IB" and e.args[0] == ErrorCode.InvalidUsage:
|
||||
pytest.skip("IB not supported on this node")
|
||||
raise
|
||||
return group, connection
|
||||
@@ -194,10 +194,10 @@ def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
def test_connection_write(mpi_group: MpiGroup, connection_type: str, nelem: int):
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
sizePerRank = nelemPerRank * memory.itemsize
|
||||
@@ -229,16 +229,16 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
|
||||
def test_connection_write_and_signal(mpi_group: MpiGroup, connection_type: str, nelem: int, device: str):
|
||||
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
|
||||
# and finally, comes back to rank 0 to make sure it matches all the original values
|
||||
|
||||
if device == "cpu" and transport == "NVLink":
|
||||
if device == "cpu" and connection_type == "NVLink":
|
||||
pytest.skip("nvlink doesn't work with host allocated memory")
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
xp = cp if device == "cuda" else np
|
||||
if group.my_rank == 0:
|
||||
memory = xp.random.randn(nelem)
|
||||
@@ -339,7 +339,7 @@ def test_nvls_connection(mpi_group: MpiGroup):
|
||||
pytest.skip("cannot use nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
all_ranks = list(range(group.nranks))
|
||||
nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
|
||||
nvls_connection = group.make_connection(all_ranks, Transport.CudaIpc, use_switch=True)
|
||||
memory1 = GpuBuffer(2**29, cp.int8)
|
||||
memory2 = GpuBuffer(2**29, cp.int8)
|
||||
memory3 = GpuBuffer(2**29, cp.int8)
|
||||
@@ -449,13 +449,13 @@ class MscclppKernel:
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
|
||||
def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
|
||||
def signal(semaphores):
|
||||
for rank in semaphores:
|
||||
semaphores[rank].signal()
|
||||
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
|
||||
@@ -530,9 +530,9 @@ def test_fifo(
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
@pytest.mark.parametrize("connection_type", ["IB", "NVLink"])
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
@@ -579,10 +579,10 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("connection_type", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
def test_port_channel(mpi_group: MpiGroup, nelem: int, connection_type: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, connection_type)
|
||||
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
if use_packet:
|
||||
|
||||
Reference in New Issue
Block a user