mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
More bindings and better test parametrization
This commit is contained in:
committed by
Saeed Maleki
parent
629949498b
commit
69938b47aa
@@ -19,6 +19,9 @@ from ._mscclpp import (
|
||||
Transport,
|
||||
TransportFlags,
|
||||
version,
|
||||
get_ib_device_count,
|
||||
get_ib_device_name,
|
||||
get_ib_transport_by_device_name,
|
||||
)
|
||||
|
||||
__version__ = version()
|
||||
|
||||
@@ -166,6 +166,10 @@ void register_core(nb::module_& m) {
|
||||
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf);
|
||||
|
||||
m.def("get_ib_device_count", &getIBDeviceCount);
|
||||
m.def("get_ib_device_name", &getIBDeviceName, nb::arg("ib_transport"));
|
||||
m.def("get_ib_transport_by_device_name", &getIBTransportByDeviceName, nb::arg("ib_device_name"));
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport, get_ib_device_count
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_group import MscclppGroup
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
@@ -17,6 +17,19 @@ from .utils import KernelBuilder, pack
|
||||
|
||||
ethernet_interface_name = "eth0"
|
||||
|
||||
# Reusable skip marker for tests that need an IB device ("no IB device").
# The device count is queried once, when this module is imported.
skipif_ib = pytest.mark.skipif(get_ib_device_count() == 0, reason="no IB device")
|
||||
|
||||
def parametrize_transport(*transports: str):
    """Build a decorator that parametrizes a test over transport names.

    Each name becomes one value of the test's ``transport`` argument. The
    ``"IB"`` name is wrapped in ``pytest.param`` carrying the ``skipif_ib``
    mark, so the mark applies to that single case only; every other name
    is passed through unchanged.

    Args:
        *transports: Transport names, e.g. ``"IB"`` or ``"NVLink"``.

    Returns:
        A decorator that applies
        ``pytest.mark.parametrize("transport", ...)`` to the function it
        receives and returns the result.
    """

    def decorator(func):
        # Attach the skip mark to the individual parameter rather than the
        # whole test, so the remaining transports are unaffected by it.
        params = [
            pytest.param(transport, marks=skipif_ib) if transport == "IB" else transport
            for transport in transports
        ]
        return pytest.mark.parametrize("transport", params)(func)

    return decorator
|
||||
|
||||
|
||||
def all_ranks_on_the_same_node(mpi_group: MpiGroup):
|
||||
if (ethernet_interface_name in ni.interfaces()) is False:
|
||||
@@ -81,13 +94,13 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
|
||||
create_and_connect(mpi_group, transport)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
@@ -122,7 +135,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
|
||||
@@ -174,6 +187,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@skipif_ib
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_and_connect(mpi_group, "IB")
|
||||
|
||||
@@ -262,7 +276,7 @@ class MscclppKernel:
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@parametrize_transport("NVLink", "IB")
|
||||
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
def signal(semaphores):
|
||||
for rank in semaphores:
|
||||
@@ -295,7 +309,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10]])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, "NVLink")
|
||||
@@ -344,7 +358,7 @@ def test_fifo(
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@parametrize_transport("IB", "NVLink")
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
|
||||
@@ -393,7 +407,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@parametrize_transport("NVLink", "IB")
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
|
||||
Reference in New Issue
Block a user