mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-07 00:05:19 +00:00
Release GIL for Python APIs with wait (#190)
This commit is contained in:
@@ -64,10 +64,10 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("nRanks"))
|
||||
.def("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
|
||||
nb::arg("timeoutSec") = 30)
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30)
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
@@ -120,7 +120,7 @@ void register_core(nb::module_& m) {
|
||||
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ void register_semaphore(nb::module_& m) {
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("poll", &Host2HostSemaphore::poll)
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::arg("max_spin_count") = 10000000);
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
|
||||
smDevice2DeviceSemaphore
|
||||
|
||||
@@ -3,13 +3,22 @@
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import threading
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
|
||||
from mscclpp import (
|
||||
TcpBootstrap,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
SmDevice2DeviceSemaphore,
|
||||
Transport,
|
||||
)
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_group import MscclppGroup
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
@@ -63,6 +72,50 @@ def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
|
||||
assert np.array_equal(memory, memory_expected)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
    """Check that a blocking ``TcpBootstrap.initialize`` call releases the GIL.

    Rank 0 creates a unique id and broadcasts it, but deliberately never calls
    ``initialize``. Every other rank calls ``initialize`` with a 3-second
    timeout on a background thread and, while that call is still pending,
    verifies that a second Python thread can start, run, and finish promptly.
    """
    comm = mpi_group.comm
    bootstrap = TcpBootstrap.create(comm.rank, comm.size)

    uniq_id = None
    if comm.rank == 0:
        # similar to NCCL's unique id
        uniq_id = bootstrap.create_unique_id()
    uniq_id_global = comm.bcast(uniq_id, 0)

    # Rank 0 never initializes the bootstrap, making the other ranks block
    # inside initialize(); only those ranks run the check.
    if comm.rank != 0:
        check_list = []

        def check_target():
            check_list.append("this thread could run.")

        def init_target():
            try:
                # expected to raise a timeout after 3 seconds
                bootstrap.initialize(uniq_id_global, 3)
            except Exception:
                # The failure itself is the expected outcome here. Catching
                # Exception (not a bare except) keeps SystemExit and
                # KeyboardInterrupt propagating.
                pass

        init_thread = threading.Thread(target=init_target)
        check_thread = threading.Thread(target=check_target)
        init_thread.start()

        # Give the init thread time to enter the blocking native call.
        time.sleep(0.1)

        # check that the check thread is not blocked
        # (monotonic clock: immune to wall-clock adjustments)
        started = time.monotonic()
        check_thread.start()
        check_thread.join()
        elapsed = time.monotonic() - started
        assert elapsed < 0.1
        assert len(check_list) == 1

        init_thread.join()

    comm.barrier()
|
||||
|
||||
|
||||
def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink for cross node")
|
||||
@@ -186,6 +239,33 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group.barrier()
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
    """Check that ``Host2HostSemaphore.wait`` releases the GIL while waiting.

    One thread waits on every semaphore while a second thread, started right
    after it, sleeps for a second and then signals every semaphore. The test
    finishes only if the signalling thread gets to run while the waiting
    thread is inside ``wait``.
    """
    group, connections = create_and_connect(mpi_group, "IB")
    semaphores = group.make_semaphore(connections, Host2HostSemaphore)

    def wait_on_all():
        # NOTE(review): -1 is passed as max_spin_count; presumably this
        # disables the spin limit -- confirm against the C++ implementation.
        for peer in connections:
            semaphores[peer].wait(-1)

    def signal_all():
        # sleep 1 sec to let wait_on_all() start a bit earlier
        time.sleep(1)
        # if wait() doesn't release GIL, this will block forever
        for peer in connections:
            semaphores[peer].signal()

    waiter = threading.Thread(target=wait_on_all)
    signaller = threading.Thread(target=signal_all)

    # Start the waiter first, then the signaller; join in the opposite order.
    waiter.start()
    signaller.start()
    signaller.join()
    waiter.join()

    group.barrier()
|
||||
|
||||
|
||||
class MscclppKernel:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user