More bindings and better test parametrization

This commit is contained in:
Olli Saarikivi
2023-10-18 23:03:28 +00:00
committed by Saeed Maleki
parent 629949498b
commit 69938b47aa
3 changed files with 29 additions and 8 deletions

View File

@@ -19,6 +19,9 @@ from ._mscclpp import (
Transport,
TransportFlags,
version,
get_ib_device_count,
get_ib_device_name,
get_ib_transport_by_device_name,
)
__version__ = version()

View File

@@ -166,6 +166,10 @@ void register_core(nb::module_& m) {
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf);
m.def("get_ib_device_count", &getIBDeviceCount);
m.def("get_ib_device_name", &getIBDeviceName, nb::arg("ib_transport"));
m.def("get_ib_transport_by_device_name", &getIBTransportByDeviceName, nb::arg("ib_device_name"));
}
NB_MODULE(_mscclpp, m) {

View File

@@ -9,7 +9,7 @@ import numpy as np
import netifaces as ni
import pytest
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport, get_ib_device_count
from ._cpp import _ext
from .mscclpp_group import MscclppGroup
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
@@ -17,6 +17,19 @@ from .utils import KernelBuilder, pack
ethernet_interface_name = "eth0"
skipif_ib = pytest.mark.skipif(get_ib_device_count() == 0, reason="no IB device")
def parametrize_transport(*transports: list):
def decorator(func):
params = []
for transport in transports:
if transport == "IB":
params.append(pytest.param(transport, marks=skipif_ib))
else:
params.append(transport)
return pytest.mark.parametrize("transport", params)(func)
return decorator
def all_ranks_on_the_same_node(mpi_group: MpiGroup):
if (ethernet_interface_name in ni.interfaces()) is False:
@@ -81,13 +94,13 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
create_and_connect(mpi_group, transport)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
group, connections = create_and_connect(mpi_group, transport)
@@ -122,7 +135,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
@@ -174,6 +187,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
@parametrize_mpi_groups(2, 4, 8, 16)
@skipif_ib
def test_h2h_semaphores(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")
@@ -262,7 +276,7 @@ class MscclppKernel:
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@parametrize_transport("NVLink", "IB")
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
def signal(semaphores):
for rank in semaphores:
@@ -295,7 +309,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("nelem", [2**i for i in [10]])
@pytest.mark.parametrize("use_packet", [False, True])
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
group, connections = create_and_connect(mpi_group, "NVLink")
@@ -344,7 +358,7 @@ def test_fifo(
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
group, connections = create_and_connect(mpi_group, transport)
@@ -393,7 +407,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@parametrize_transport("NVLink", "IB")
@pytest.mark.parametrize("use_packet", [False, True])
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_and_connect(mpi_group, transport)