mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 21:51:32 +00:00
Release GIL for Python APIs with wait (#190)
This commit is contained in:
@@ -3,13 +3,22 @@
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import threading
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
|
||||
from mscclpp import (
|
||||
TcpBootstrap,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
SmDevice2DeviceSemaphore,
|
||||
Transport,
|
||||
)
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_group import MscclppGroup
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
@@ -63,6 +72,50 @@ def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
|
||||
assert np.array_equal(memory, memory_expected)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
    """Check that ``TcpBootstrap.initialize`` releases the GIL while it blocks.

    Rank 0 creates the unique id but never initializes its bootstrap, so the
    ``initialize`` call on every other rank is expected to block until its
    timeout expires. While that call is blocked in a worker thread, a second
    Python thread must still be able to run and finish promptly.
    """
    # Timeout (seconds) handed to initialize(); also the upper bound for the
    # whole responsiveness check below.
    init_timeout_sec = 3

    bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
    uniq_id = None
    if mpi_group.comm.rank == 0:
        # similar to NCCL's unique id
        uniq_id = bootstrap.create_unique_id()
    uniq_id_global = mpi_group.comm.bcast(uniq_id, 0)

    # rank 0 never initializes the bootstrap, making other ranks block
    if mpi_group.comm.rank != 0:
        check_list = []

        def check_target():
            check_list.append("this thread could run.")

        def init_target():
            try:
                # expected to raise a timeout after init_timeout_sec seconds
                bootstrap.initialize(uniq_id_global, init_timeout_sec)
            except Exception:
                # The timeout error is the expected outcome here. Catch only
                # Exception so KeyboardInterrupt/SystemExit are not swallowed.
                pass

        init_thread = threading.Thread(target=init_target)
        check_thread = threading.Thread(target=check_target)

        # Start the clock BEFORE the blocking call is launched. If initialize()
        # kept holding the GIL, this thread could not resume from the sleep
        # below until the call timed out, and a clock started only after the
        # sleep would never notice the stall.
        overall_start = time.monotonic()
        init_thread.start()

        time.sleep(0.1)

        # check that the check thread is not blocked
        s = time.monotonic()
        check_thread.start()
        check_thread.join()
        e = time.monotonic()
        assert e - s < 0.1
        assert len(check_list) == 1
        # The whole check must complete while initialize() is still waiting
        # for its timeout, i.e. well before init_timeout_sec has elapsed.
        assert e - overall_start < init_timeout_sec, "initialize() appears to block other Python threads"

        init_thread.join()

    mpi_group.comm.barrier()
|
||||
|
||||
|
||||
def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink for cross node")
|
||||
@@ -186,6 +239,33 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group.barrier()
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
    """Check that ``Host2HostSemaphore.wait`` lets other Python threads run.

    One thread waits on every semaphore while a second thread, started right
    after it, sleeps for a second and then signals every semaphore. Both
    threads are joined before the final group barrier.
    """
    group, connections = create_and_connect(mpi_group, "IB")
    semaphores = group.make_semaphore(connections, Host2HostSemaphore)

    def wait_on_all(sems, conns):
        # -1 is presumably "no timeout" -- confirm against the binding.
        for peer in conns:
            sems[peer].wait(-1)

    def signal_all(sems, conns):
        # Sleep 1 sec so that wait_on_all() gets a head start.
        time.sleep(1)
        # If wait() does not release the GIL, this loop is never reached and
        # the test blocks forever.
        for peer in conns:
            sems[peer].signal()

    waiter = threading.Thread(target=wait_on_all, args=(semaphores, connections))
    signaler = threading.Thread(target=signal_all, args=(semaphores, connections))

    # Launch the waiter first, then the signaler.
    for worker in (waiter, signaler):
        worker.start()
    # Join in the opposite order: signaler first, then the waiter.
    for worker in (signaler, waiter):
        worker.join()

    group.barrier()
|
||||
|
||||
|
||||
class MscclppKernel:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user