Files
mscclpp/python/test/mscclpp_mpi.py
Binyang Li eda74a7f29 Add handle cache for AMD platform (#698)
Introduce handle cache for AMD platform.
Avoids reaching the handle limit if we open too many IPC handles.

For NVIDIA, we don't need this feature since NVIDIA counts handle
references internally and reuses the same handle if it is already
open.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Binyang2014 <9415966+Binyang2014@users.noreply.github.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
2025-12-21 18:39:12 -08:00

62 lines
1.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from mpi4py import MPI
import pytest
# Expected GPU count per node — presumably matches the test cluster's hardware; TODO confirm.
N_GPUS_PER_NODE = 8
logging.basicConfig(level=logging.INFO)
# Cache of MpiGroup instances keyed by the tuple of member ranks, so repeated
# parametrizations reuse the same communicator instead of creating new handles.
_mpi_group_cache = {}
class MpiGroup:
    """Wraps an MPI communicator built from a subset of COMM_WORLD ranks.

    An empty ``ranks`` selects ``MPI.COMM_WORLD`` itself; otherwise a new
    communicator containing only ``ranks`` is created. Ranks outside the
    list receive ``MPI.COMM_NULL`` from ``Create`` and should be skipped
    by the caller (see the ``mpi_group`` fixture).
    """

    def __init__(self, ranks: list = None):
        # ``None`` replaces the original mutable default ``[]`` — same
        # behavior for callers, without the shared-list pitfall.
        ranks = [] if ranks is None else ranks
        if len(ranks) == 0:
            self.comm = MPI.COMM_WORLD
        else:
            world_group = MPI.COMM_WORLD.group
            group = world_group.Incl(ranks)
            self.comm = MPI.COMM_WORLD.Create(group)
            # The group handle is only needed to create the communicator.
            group.Free()

    def __del__(self):
        # ``getattr`` guards against a partially-constructed object whose
        # __init__ raised before ``self.comm`` was assigned.
        comm = getattr(self, "comm", MPI.COMM_NULL)
        # Never free COMM_WORLD (forbidden by the MPI standard) or a null
        # communicator, and never touch MPI after finalization.
        if (
            comm != MPI.COMM_NULL
            and comm != MPI.COMM_WORLD
            and MPI.Is_initialized()
            and not MPI.Is_finalized()
        ):
            comm.Free()
@pytest.fixture
def mpi_group(request: pytest.FixtureRequest):
    """Indirect fixture yielding the MpiGroup selected by parametrization.

    All ranks synchronize on entry; ranks that are not members of the
    requested communicator (``MPI.COMM_NULL``) skip the test. A final
    barrier keeps ranks in lockstep between tests, guarded so it is not
    attempted once MPI has been finalized.
    """
    MPI.COMM_WORLD.barrier()
    group = request.param
    try:
        if group.comm == MPI.COMM_NULL:
            pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
        yield group
    finally:
        if MPI.Is_initialized() and not MPI.Is_finalized():
            MPI.COMM_WORLD.barrier()
def parametrize_mpi_groups(*tuples: tuple):
    """Decorator that parametrizes a test over MPI groups of the given sizes.

    Each positional argument is a desired communicator size. Sizes larger
    than the current world size are skipped with a warning. ``MpiGroup``
    instances are cached per rank-tuple (see ``_mpi_group_cache``) so
    repeated decorations reuse communicators instead of opening new handles.
    The resulting groups are fed to the ``mpi_group`` fixture indirectly.
    """
    def decorator(func):
        mpi_groups = []
        # Iterate the varargs directly — no need to copy them into a list.
        for group_size in tuples:
            if MPI.COMM_WORLD.size < group_size:
                logging.warning(f"MPI.COMM_WORLD.size < {group_size}, skip")
                continue
            ranks_key = tuple(range(group_size))
            if ranks_key not in _mpi_group_cache:
                _mpi_group_cache[ranks_key] = MpiGroup(list(ranks_key))
            mpi_groups.append(_mpi_group_cache[ranks_key])
        return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)
    return decorator