Files
mscclpp/python/test/mscclpp_mpi.py
Changho Hwang 060fda12e6 mscclpp-test in Python (#204)
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Saeed Maleki <saemal@microsoft.com>
Co-authored-by: Esha Choukse <eschouks@microsoft.com>
2023-11-16 12:45:25 +08:00

76 lines
1.8 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import atexit
import logging
from typing import Optional

import cupy as cp
import mpi4py

# These rc flags must be set after `import mpi4py` but before
# `from mpi4py import MPI`: this module manages the MPI lifecycle itself.
mpi4py.rc.initialize = False
mpi4py.rc.finalize = False

from mpi4py import MPI
import pytest
# Assumed GPUs per node. NOTE(review): init_mpi() assigns N_GPUS_PER_NODE from
# the node-local communicator size, but without a `global` declaration that
# assignment only binds a local inside init_mpi — confirm which is intended.
N_GPUS_PER_NODE = 8

# INFO level so the "skip" warnings in parametrize_mpi_groups are visible.
logging.basicConfig(level=logging.INFO)
def init_mpi():
    """Initialize MPI once and bind this process to its node-local GPU.

    If MPI is not yet initialized, initializes it, derives the number of
    ranks sharing this node via a shared-memory communicator split, updates
    the module-level ``N_GPUS_PER_NODE`` accordingly, and selects CUDA
    device ``world_rank % N_GPUS_PER_NODE``. No-op if MPI is already up.
    """
    # Bug fix: without this declaration the assignment below bound a local
    # that shadowed the module constant, so importers always saw the default.
    global N_GPUS_PER_NODE
    if not MPI.Is_initialized():
        MPI.Init()
        # Ranks that share a node (shared memory) tell us the local rank count,
        # which this module equates with the number of GPUs per node.
        shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
        N_GPUS_PER_NODE = shm_comm.size
        shm_comm.Free()
        cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()
# Shutdown hook: tear down MPI exactly once at interpreter exit.
def finalize_mpi():
    """Finalize MPI if it is currently initialized; otherwise do nothing."""
    if not MPI.Is_initialized():
        return
    MPI.Finalize()
# Register the function to be called on exit, so MPI.Finalize() runs at
# interpreter shutdown without each test module having to call it.
atexit.register(finalize_mpi)
class MpiGroup:
    """Wraps an MPI communicator built from a subset of world ranks.

    With no ranks given, ``comm`` is ``MPI.COMM_WORLD``. Otherwise ``comm``
    is a communicator over exactly ``ranks``; ranks outside the subset get
    ``MPI.COMM_NULL`` from ``Create`` (see parametrize_mpi_groups).
    """

    def __init__(self, ranks: Optional[list] = None):
        """Create a communicator over ``ranks`` (all world ranks if None/empty).

        Bug fix: the original used a mutable default argument (``ranks=[]``);
        ``None`` is the sentinel now, with identical behavior for callers.
        """
        if not ranks:
            self.comm = MPI.COMM_WORLD
        else:
            group = MPI.COMM_WORLD.group.Incl(ranks)
            self.comm = MPI.COMM_WORLD.Create(group)
@pytest.fixture
def mpi_group(request: pytest.FixtureRequest):
    """Yield the MpiGroup for this parametrized case, skipping uninvolved ranks.

    All ranks synchronize on a barrier first; ranks whose param is None
    (i.e. not part of the group) skip the test.
    """
    MPI.COMM_WORLD.barrier()
    group = request.param
    if group is None:
        pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
    yield group
def parametrize_mpi_groups(*tuples: tuple):
    """Decorator factory: parametrize a test over MPI groups of the given sizes.

    For each requested size, builds an MpiGroup over ranks [0, size). Sizes
    larger than the world size are skipped with a warning. On ranks outside
    a group (comm == MPI.COMM_NULL) the param is None, which the mpi_group
    fixture turns into a pytest skip.
    """

    def decorator(func):
        mpi_groups = []
        # `tuples` is already iterable; no need to materialize it with list().
        for group_size in tuples:
            if MPI.COMM_WORLD.size < group_size:
                # Lazy %-formatting: only rendered if the record is emitted.
                logging.warning("MPI.COMM_WORLD.size < %s, skip", group_size)
                continue
            mpi_group = MpiGroup(list(range(group_size)))
            # Ranks not in the group get a NULL communicator; map them to None.
            mpi_groups.append(None if mpi_group.comm == MPI.COMM_NULL else mpi_group)
        return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)

    return decorator
# Module-import side effect: any test module importing this one gets MPI
# initialized and its GPU selected immediately.
init_mpi()