Fix NVLS support (#258)

* Do not compile nvls_test with ROCm
* Fix multi-node tests
This commit is contained in:
Changho Hwang
2024-02-06 15:24:13 -08:00
committed by GitHub
parent d34e097b40
commit 6a19b19ece
3 changed files with 21 additions and 8 deletions

View File

@@ -86,15 +86,18 @@ class CommGroup:
) -> dict[int, Connection]:
if type(endpoints) is Transport:
endpoints = EndpointConfig(endpoints)
if endpoints.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
elif type(endpoints) is dict:
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
connections = {}
for rank in all_ranks:
if type(endpoints) is dict:
endpoint = endpoints[rank]
else:
endpoint = endpoints
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
if endpoint.transport == Transport.Nvls:
connections[rank] = self.communicator.connct_nvls_collective(all_ranks, endpoint)
else:
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()
connections = {rank: connections[rank].get() for rank in connections}
return connections