mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 14:29:13 +00:00
Properly setting up the device in Ethernet Connection (#527)
When we create the thread to receive messages in the Ethernet Connection, it resets the Device ID, causing faults in the Ethernet Connection unit tests.  This PR aims to properly set up the device when the thread is created. --------- Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
@@ -325,15 +325,17 @@ def test_nvls_connection(mpi_group: MpiGroup):
|
||||
pytest.skip("cannot use nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
all_ranks = list(range(group.nranks))
|
||||
endpoint = EndpointConfig(Transport.Nvls, 2**22)
|
||||
nvls_connection = group.make_connection(all_ranks, endpoint)
|
||||
mem_handle1 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle2 = nvls_connection.allocate_bind_memory(2**21)
|
||||
nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
|
||||
memory1 = GpuBuffer(2**29, cp.int8)
|
||||
memory2 = GpuBuffer(2**29, cp.int8)
|
||||
memory3 = GpuBuffer(2**29, cp.int8)
|
||||
mem_handle1 = nvls_connection.bind_allocated_memory(memory1.data.ptr, memory1.data.mem.size)
|
||||
mem_handle2 = nvls_connection.bind_allocated_memory(memory2.data.ptr, memory2.data.mem.size)
|
||||
with pytest.raises(Exception):
|
||||
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
|
||||
# the memory is freed on the destructor of mem_handle2
|
||||
mem_handle2 = None
|
||||
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
|
||||
|
||||
|
||||
class MscclppKernel:
|
||||
@@ -610,8 +612,9 @@ def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packe
|
||||
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
|
||||
def test_nvls(mpi_group: MpiGroup):
|
||||
group, nvls_connection = create_group_and_connection(mpi_group, "NVLS")
|
||||
memory = GpuBuffer(2**21, dtype=cp.int8)
|
||||
nbytes = 2**21
|
||||
mem_handle = nvls_connection.allocate_bind_memory(nbytes)
|
||||
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
|
||||
|
||||
nvlinks_connections = create_connection(group, "NVLink")
|
||||
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
|
||||
|
||||
Reference in New Issue
Block a user