mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
New semaphore constructors (#559)
More intuitive interfaces for creating semaphores and channels. Also allows channel construction using third-party bootstrappers directly without overriding MSCCL++ Bootstrap.
This commit is contained in:
@@ -28,6 +28,8 @@ from mscclpp import (
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
env,
|
||||
Device,
|
||||
DeviceType,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
@@ -280,7 +282,13 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
for rank in connections:
|
||||
@@ -293,7 +301,13 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
tran = group.my_ib_device(group.my_rank % 8)
|
||||
endpoint = EndpointConfig(tran, Device(DeviceType.CPU))
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
|
||||
connections = {rank: conn.get() for rank, conn in connections.items()}
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user