mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
Fix NVLS support (#258)
* Do not compile nvls_test with ROCm * Fix multi-node tests
This commit is contained in:
@@ -86,15 +86,18 @@ class CommGroup:
|
||||
) -> dict[int, Connection]:
|
||||
if type(endpoints) is Transport:
|
||||
endpoints = EndpointConfig(endpoints)
|
||||
if endpoints.transport == Transport.Nvls:
|
||||
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
|
||||
elif type(endpoints) is dict:
|
||||
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
|
||||
connections = {}
|
||||
for rank in all_ranks:
|
||||
if type(endpoints) is dict:
|
||||
endpoint = endpoints[rank]
|
||||
else:
|
||||
endpoint = endpoints
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
|
||||
if endpoint.transport == Transport.Nvls:
|
||||
connections[rank] = self.communicator.connct_nvls_collective(all_ranks, endpoint)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
|
||||
self.communicator.setup()
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
Reference in New Issue
Block a user