Release GIL for Python APIs with wait (#190)

This commit is contained in:
Changho Hwang
2023-11-14 21:11:01 +08:00
committed by GitHub
parent 3521fb0280
commit 4cdb100265
3 changed files with 88 additions and 7 deletions

View File

@@ -3,13 +3,22 @@
from concurrent.futures import ThreadPoolExecutor
import time
import threading
import cupy as cp
import numpy as np
import netifaces as ni
import pytest
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
from mscclpp import (
TcpBootstrap,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
SmDevice2DeviceSemaphore,
Transport,
)
from ._cpp import _ext
from .mscclpp_group import MscclppGroup
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
@@ -63,6 +72,50 @@ def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
assert np.array_equal(memory, memory_expected)
@parametrize_mpi_groups(2, 4, 8, 16)
def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
uniq_id = None
if mpi_group.comm.rank == 0:
# similar to NCCL's unique id
uniq_id = bootstrap.create_unique_id()
uniq_id_global = mpi_group.comm.bcast(uniq_id, 0)
if mpi_group.comm.rank == 0:
# rank 0 never initializes the bootstrap, making other ranks block
pass
else:
check_list = []
def check_target():
check_list.append("this thread could run.")
def init_target():
try:
# expected to raise a timeout after 3 seconds
bootstrap.initialize(uniq_id_global, 3)
except:
pass
init_thread = threading.Thread(target=init_target)
check_thread = threading.Thread(target=check_target)
init_thread.start()
time.sleep(0.1)
# check that the check thread is not blocked
s = time.time()
check_thread.start()
check_thread.join()
e = time.time()
assert e - s < 0.1
assert len(check_list) == 1
init_thread.join()
mpi_group.comm.barrier()
def create_and_connect(mpi_group: MpiGroup, transport: str):
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink for cross node")
@@ -186,6 +239,33 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
group.barrier()
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
def target_wait(sems, conns):
for rank in conns:
sems[rank].wait(-1)
def target_signal(sems, conns):
# sleep 1 sec to let target_wait() starts a bit earlier
time.sleep(1)
# if wait() doesn't release GIL, this will block forever
for rank in conns:
sems[rank].signal()
wait_thread = threading.Thread(target=target_wait, args=(semaphores, connections))
signal_thread = threading.Thread(target=target_signal, args=(semaphores, connections))
wait_thread.start()
signal_thread.start()
signal_thread.join()
wait_thread.join()
group.barrier()
class MscclppKernel:
def __init__(
self,