mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 05:31:27 +00:00
Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Saeed Maleki <saemal@microsoft.com> Co-authored-by: Esha Choukse <eschouks@microsoft.com>
76 lines
1.8 KiB
Python
76 lines
1.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import atexit
|
|
import logging
|
|
|
|
import cupy as cp
|
|
import mpi4py
|
|
|
|
mpi4py.rc.initialize = False
|
|
mpi4py.rc.finalize = False
|
|
|
|
from mpi4py import MPI
|
|
import pytest
|
|
|
|
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

# Default number of GPUs assumed per node.
N_GPUS_PER_NODE: int = 8
|
|
|
|
|
|
def init_mpi():
    """Initialize MPI if needed and bind this process to a GPU on its node.

    MPI is initialized explicitly because ``mpi4py.rc.initialize`` is set to
    ``False`` before ``mpi4py.MPI`` is imported. The ranks that share a node
    are discovered with ``COMM_TYPE_SHARED``. Their count is stored in the
    module-level ``N_GPUS_PER_NODE``, and this process's node-local rank
    selects the CuPy device.
    """
    # Without this declaration, the assignment below would only create a
    # local variable and the module-level default would never be updated.
    global N_GPUS_PER_NODE
    if not MPI.Is_initialized():
        MPI.Init()
    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
    try:
        N_GPUS_PER_NODE = shm_comm.size
        # Use the node-local rank rather than ``world_rank % ranks_per_node``.
        # The two only agree when ranks are laid out contiguously per node.
        local_rank = shm_comm.rank
    finally:
        # The shared-memory communicator is only needed for discovery.
        shm_comm.Free()
    cp.cuda.Device(local_rank).use()
|
|
|
|
|
|
# Define a function to finalize MPI
def finalize_mpi():
    """Finalize MPI at interpreter exit if it is still active.

    ``mpi4py.rc.finalize`` is ``False``, so mpi4py will not finalize on its
    own. ``MPI.Is_initialized()`` keeps returning true after ``MPI.Finalize()``
    has run. Checking ``MPI.Is_finalized()`` as well prevents a second,
    erroneous ``Finalize`` call if something already finalized MPI.
    """
    if MPI.Is_initialized() and not MPI.Is_finalized():
        MPI.Finalize()


# Register the function to be called on exit
atexit.register(finalize_mpi)
|
|
|
|
|
|
class MpiGroup:
    """Hold an MPI communicator spanning a subset of ``MPI.COMM_WORLD``.

    Attributes:
        comm: ``MPI.COMM_WORLD`` when no ranks are given. Otherwise, the
            communicator returned by ``MPI.COMM_WORLD.Create`` for the given
            ranks, which is ``MPI.COMM_NULL`` on ranks outside the group.
    """

    def __init__(self, ranks: "list[int] | None" = None):
        """Build the communicator for ``ranks``.

        ``COMM_WORLD.Create`` is collective over ``COMM_WORLD``, so every world
        rank must construct the group with the same ``ranks``.

        Args:
            ranks: World ranks to include. ``None`` (the default) or an empty
                list selects all of ``MPI.COMM_WORLD``.
        """
        # ``None`` replaces the former mutable ``[]`` default; both mean "world".
        if not ranks:
            self.comm = MPI.COMM_WORLD
            return
        world_group = MPI.COMM_WORLD.group
        group = world_group.Incl(list(ranks))
        try:
            self.comm = MPI.COMM_WORLD.Create(group)
        finally:
            # Group handles are not needed once the communicator exists;
            # free them instead of leaking one pair per MpiGroup.
            group.Free()
            world_group.Free()
|
|
|
|
|
|
@pytest.fixture
def mpi_group(request: pytest.FixtureRequest):
    """Yield the indirect parameter supplied for ``mpi_group``.

    All world ranks first meet at a ``COMM_WORLD`` barrier. A rank whose
    parameter is ``None`` skips the test instead of receiving a group.
    """
    world = MPI.COMM_WORLD
    world.barrier()
    selected = request.param
    if selected is not None:
        yield selected
    else:
        pytest.skip(f"Skip for rank {world.rank}")
|
|
|
|
|
|
def parametrize_mpi_groups(*tuples: int):
    """Parametrize a test over MPI groups made of the first ``n`` world ranks.

    For each requested size ``n``, the returned decorator builds an
    ``MpiGroup`` over world ranks ``0 .. n-1`` on every rank (the creation is
    collective). Sizes larger than ``MPI.COMM_WORLD.size`` are skipped with a
    warning. On ranks outside a group, the parameter is ``None`` so that the
    ``mpi_group`` fixture can skip them. The test is parametrized indirectly
    through the ``mpi_group`` fixture.

    Args:
        tuples: Group sizes (the name is kept for backward compatibility;
            each item is an int).

    Raises:
        ValueError: if any size is smaller than 1. A size of 0 would otherwise
            produce an empty rank list, which ``MpiGroup`` treats as all of
            ``COMM_WORLD``.
    """
    # Validate eagerly, before any (collective) MPI work is attempted.
    for group_size in tuples:
        if group_size < 1:
            raise ValueError(f"MPI group size must be positive, got {group_size}")

    def decorator(func):
        mpi_groups = []
        for group_size in tuples:
            if MPI.COMM_WORLD.size < group_size:
                logging.warning(f"MPI.COMM_WORLD.size < {group_size}, skip")
                continue
            mpi_group = MpiGroup(list(range(group_size)))
            # Non-members receive COMM_NULL from COMM_WORLD.Create.
            mpi_groups.append(None if mpi_group.comm == MPI.COMM_NULL else mpi_group)
        return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)

    return decorator
|
|
|
|
|
|
# Initialize MPI (if not already initialized) and select this rank's CuPy
# device as soon as this module is imported.
init_mpi()
|