mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
More bindings and better test parametrization
This commit is contained in:
committed by
Saeed Maleki
parent
629949498b
commit
69938b47aa
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport, get_ib_device_count
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_group import MscclppGroup
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
@@ -17,6 +17,19 @@ from .utils import KernelBuilder, pack
|
||||
|
||||
ethernet_interface_name = "eth0"
|
||||
|
||||
skipif_ib = pytest.mark.skipif(get_ib_device_count() == 0, reason="no IB device")
|
||||
|
||||
def parametrize_transport(*transports: list):
|
||||
def decorator(func):
|
||||
params = []
|
||||
for transport in transports:
|
||||
if transport == "IB":
|
||||
params.append(pytest.param(transport, marks=skipif_ib))
|
||||
else:
|
||||
params.append(transport)
|
||||
return pytest.mark.parametrize("transport", params)(func)
|
||||
return decorator
|
||||
|
||||
|
||||
def all_ranks_on_the_same_node(mpi_group: MpiGroup):
|
||||
if (ethernet_interface_name in ni.interfaces()) is False:
|
||||
@@ -81,13 +94,13 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
|
||||
create_and_connect(mpi_group, transport)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
@@ -122,7 +135,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
|
||||
@@ -174,6 +187,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@skipif_ib
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_and_connect(mpi_group, "IB")
|
||||
|
||||
@@ -262,7 +276,7 @@ class MscclppKernel:
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@parametrize_transport("NVLink", "IB")
|
||||
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
def signal(semaphores):
|
||||
for rank in semaphores:
|
||||
@@ -295,7 +309,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10]])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, "NVLink")
|
||||
@@ -344,7 +358,7 @@ def test_fifo(
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
|
||||
@@ -393,7 +407,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@parametrize_transport("NVLink", "IB")
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
|
||||
Reference in New Issue
Block a user