mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-03-27 18:47:49 +00:00
Introduce a handle cache for the AMD platform, to avoid hitting the handle limit when too many IPC handles are opened. NVIDIA does not need this feature, since NVIDIA counts handle references internally and reuses the same handle if it has already been opened. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Binyang2014 <9415966+Binyang2014@users.noreply.github.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import logging
from typing import Optional

from mpi4py import MPI
import pytest
|
|
|
|
# Number of GPUs assumed per node; tests size their rank groups against this.
# NOTE(review): assumed 8-GPU nodes — confirm against the target hardware.
N_GPUS_PER_NODE = 8

logging.basicConfig(level=logging.INFO)

# Cache of MpiGroup instances keyed by the tuple of member ranks, so repeated
# parametrize_mpi_groups() calls reuse one communicator per rank set instead
# of creating (and later freeing) duplicate MPI communicators.
_mpi_group_cache = {}
|
|
|
|
|
|
class MpiGroup:
    """Wrapper around an MPI communicator built from a subset of world ranks.

    An empty ``ranks`` selects ``MPI.COMM_WORLD`` itself; otherwise a new
    communicator containing exactly ``ranks`` is created. On ranks outside
    ``ranks``, ``self.comm`` is ``MPI.COMM_NULL``.
    """

    def __init__(self, ranks: Optional[list] = None):
        """Create the communicator.

        Args:
            ranks: World ranks to include. ``None`` or an empty list selects
                the full ``MPI.COMM_WORLD`` communicator.
        """
        # None replaces the original mutable default ([]) as the sentinel,
        # avoiding the shared-mutable-default pitfall; an empty list still
        # behaves the same as before.
        if not ranks:
            self.comm = MPI.COMM_WORLD
        else:
            world_group = MPI.COMM_WORLD.group
            group = world_group.Incl(ranks)
            # Collective over COMM_WORLD; ranks not in `group` get COMM_NULL.
            self.comm = MPI.COMM_WORLD.Create(group)

    def __del__(self):
        # Free only communicators this object created: never COMM_NULL
        # (nothing to free) and never COMM_WORLD (owned by the MPI runtime —
        # calling Free() on it, as the original code could, is an MPI error).
        # getattr guards against __init__ having raised before setting comm,
        # and the Is_initialized/Is_finalized checks make this safe during
        # interpreter shutdown.
        comm = getattr(self, "comm", MPI.COMM_NULL)
        if (
            comm != MPI.COMM_NULL
            and comm != MPI.COMM_WORLD
            and MPI.Is_initialized()
            and not MPI.Is_finalized()
        ):
            comm.Free()
|
|
|
|
|
|
@pytest.fixture
def mpi_group(request: pytest.FixtureRequest):
    """Indirect fixture yielding the MpiGroup supplied via parametrization.

    Ranks whose communicator is ``MPI.COMM_NULL`` (ranks excluded from the
    group) skip the test. World-wide barriers before and after keep all
    ranks in lockstep across parametrized cases.
    """
    MPI.COMM_WORLD.barrier()

    group = request.param
    try:
        if group.comm == MPI.COMM_NULL:
            pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
        yield group
    finally:
        # Only touch MPI while it is still usable (not yet finalized).
        if MPI.Is_initialized() and not MPI.Is_finalized():
            MPI.COMM_WORLD.barrier()
|
|
|
|
|
|
def parametrize_mpi_groups(*tuples: int):
    """Parametrize a test over MPI groups of the given sizes.

    Args:
        *tuples: Desired group sizes (rank counts). Each size produces a
            group containing ranks ``0..size-1``. Sizes larger than the
            current MPI world are skipped with a warning. (Annotation fixed
            from ``tuple`` to ``int`` — callers pass plain sizes, which are
            compared against ``MPI.COMM_WORLD.size`` below.)

    Returns:
        A decorator applying ``pytest.mark.parametrize`` indirectly over the
        ``mpi_group`` fixture.
    """

    def decorator(func):
        mpi_groups = []
        # `tuples` is already an iterable tuple; the original's list() wrap
        # was redundant and has been dropped.
        for group_size in tuples:
            if MPI.COMM_WORLD.size < group_size:
                logging.warning(f"MPI.COMM_WORLD.size < {group_size}, skip")
                continue
            ranks = list(range(group_size))
            ranks_key = tuple(ranks)
            # Reuse cached groups so several tests requesting the same rank
            # set share one communicator instead of exhausting MPI resources.
            if ranks_key not in _mpi_group_cache:
                _mpi_group_cache[ranks_key] = MpiGroup(ranks)
            mpi_groups.append(_mpi_group_cache[ranks_key])
        return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)

    return decorator
|