mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Properly setting up the device in Ethernet Connection (#527)
When we create the thread to receive messages in the Ethernet Connection, it resets the Device ID, causing faults in the Ethernet Connection unit tests.  This PR aims to properly set up the device when the thread is created. --------- Co-authored-by: Binyang Li <binyli@microsoft.com>
This commit is contained in:
@@ -325,15 +325,17 @@ def test_nvls_connection(mpi_group: MpiGroup):
|
||||
pytest.skip("cannot use nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
all_ranks = list(range(group.nranks))
|
||||
endpoint = EndpointConfig(Transport.Nvls, 2**22)
|
||||
nvls_connection = group.make_connection(all_ranks, endpoint)
|
||||
mem_handle1 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle2 = nvls_connection.allocate_bind_memory(2**21)
|
||||
nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
|
||||
memory1 = GpuBuffer(2**29, cp.int8)
|
||||
memory2 = GpuBuffer(2**29, cp.int8)
|
||||
memory3 = GpuBuffer(2**29, cp.int8)
|
||||
mem_handle1 = nvls_connection.bind_allocated_memory(memory1.data.ptr, memory1.data.mem.size)
|
||||
mem_handle2 = nvls_connection.bind_allocated_memory(memory2.data.ptr, memory2.data.mem.size)
|
||||
with pytest.raises(Exception):
|
||||
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
|
||||
# the memory is freed on the destructor of mem_handle2
|
||||
mem_handle2 = None
|
||||
mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
|
||||
mem_handle3 = nvls_connection.bind_allocated_memory(memory3.data.ptr, memory3.data.mem.size)
|
||||
|
||||
|
||||
class MscclppKernel:
|
||||
@@ -610,8 +612,9 @@ def test_port_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packe
|
||||
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
|
||||
def test_nvls(mpi_group: MpiGroup):
|
||||
group, nvls_connection = create_group_and_connection(mpi_group, "NVLS")
|
||||
memory = GpuBuffer(2**21, dtype=cp.int8)
|
||||
nbytes = 2**21
|
||||
mem_handle = nvls_connection.allocate_bind_memory(nbytes)
|
||||
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
|
||||
|
||||
nvlinks_connections = create_connection(group, "NVLink")
|
||||
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
|
||||
|
||||
@@ -273,7 +273,12 @@ EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEn
|
||||
t.join();
|
||||
|
||||
// Starting Thread to Receive Messages
|
||||
threadRecvMessages_ = std::thread(&EthernetConnection::recvMessages, this);
|
||||
int deviceId = -1;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
|
||||
threadRecvMessages_ = std::thread([deviceId, this]() {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(deviceId));
|
||||
this->recvMessages();
|
||||
});
|
||||
|
||||
INFO(MSCCLPP_NET, "Ethernet connection created");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user