Properly setting up the device in Ethernet Connection (#527)

Creating the thread that receives messages in the Ethernet
Connection resets the Device ID, causing faults in the Ethernet
Connection unit tests.


![image](https://github.com/user-attachments/assets/ba609c16-0f52-4624-807a-5ad776a0c18d)

This PR aims to properly set up the device when the thread is created.

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
Caio Rocha
2025-05-19 10:05:45 -07:00
committed by GitHub
parent a18e91cee4
commit 29c3af2ac6
2 changed files with 16 additions and 8 deletions

View File

@@ -325,15 +325,17 @@ def test_nvls_connection(mpi_group: MpiGroup):
pytest.skip("cannot use nvls for cross node")
group = mscclpp_comm.CommGroup(mpi_group.comm)
all_ranks = list(range(group.nranks))
endpoint = EndpointConfig(Transport.Nvls, 2**22)
nvls_connection = group.make_connection(all_ranks, endpoint)
mem_handle1 = nvls_connection.allocate_bind_memory(2**21)
mem_handle2 = nvls_connection.allocate_bind_memory(2**21)
nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
memory1 = GpuBuffer(2**29, cp.int8)
memory2 = GpuBuffer(2**29, cp.int8)
memory3 = GpuBuffer(2**29, cp.int8)
mem_handle1 = nvls_connection.bind_allocated_memory(memory1.data.ptr, memory1.data.mem.size)
mem_handle2 = nvls_connection.bind_allocated_memory(memory2.data.ptr, memory2.data.mem.size)
with pytest.raises(Exception):
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
# the memory is freed on the destructor of mem_handle2
mem_handle2 = None
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
class MscclppKernel:
@@ -610,8 +612,9 @@ def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packe
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
def test_nvls(mpi_group: MpiGroup):
group, nvls_connection = create_group_and_connection(mpi_group, "NVLS")
memory = GpuBuffer(2**21, dtype=cp.int8)
nbytes = 2**21
mem_handle = nvls_connection.allocate_bind_memory(nbytes)
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
nvlinks_connections = create_connection(group, "NVLink")
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)