This commit is contained in:
Crutcher Dunnavant
2023-04-06 00:15:08 +00:00
committed by Crutcher Dunnavant
parent 7753c38eb1
commit e65def8657
4 changed files with 45 additions and 17 deletions

View File

@@ -3,6 +3,8 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>
#include <cstring>
@@ -69,6 +71,15 @@ void checkResult(
}
}
// Retry a call that may transiently report mscclppInProgress, spinning
// until it settles, then validate the final result via checkResult
// (which raises on any non-success code; extra args form the message).
//
// Wrapped in the do { ... } while (0) idiom — unlike a bare { ... }
// block, this makes the macro behave as a single statement, so
// `if (cond) RETRY(call, "msg"); else ...` compiles correctly and the
// caller's trailing semicolon is consumed.
#define RETRY(C, ...)                       \
  do {                                      \
    mscclppResult_t res;                    \
    do {                                    \
      res = (C);                            \
    } while (res == mscclppInProgress);     \
    checkResult(res, __VA_ARGS__);          \
  } while (0)
// Maybe return the value, maybe throw an exception.
template <typename Val, typename... Args>
Val maybe(
@@ -136,6 +147,11 @@ NB_MODULE(_py_mscclpp, m) {
mscclppSetLogHandler(mscclppDefaultLogHandler);
});
// NOTE(review): appears to exist purely to touch the CUDA runtime at
// import time (forcing lazy context/runtime initialization); `device` is
// never used and the cudaError_t returned by cudaGetDevice is ignored —
// confirm the call is intended only for its side effect, and consider
// checking the return code so a broken CUDA setup fails loudly here.
m.def("_setup", []() {
int device;
cudaGetDevice(&device);
});
nb::enum_<mscclppTransport_t>(m, "TransportType")
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
@@ -228,7 +244,7 @@ NB_MODULE(_py_mscclpp, m) {
uint64_t local_buff,
uint64_t buff_size,
mscclppTransport_t transport_type) -> void {
checkResult(
RETRY(
mscclppConnect(
self._handle,
remote_rank,
@@ -236,7 +252,7 @@ NB_MODULE(_py_mscclpp, m) {
reinterpret_cast<void*>(local_buff),
buff_size,
transport_type,
0 // ibDev
NULL // ibDev
),
"Connect failed");
},
@@ -249,12 +265,10 @@ NB_MODULE(_py_mscclpp, m) {
"Attach a local buffer to a remote connection.")
.def(
"connection_setup",
[](_Comm& comm) {
[](_Comm& comm) -> void {
comm.check_open();
return maybe(
mscclppConnectionSetup(comm._handle),
true,
"Failed to setup MSCCLPP connection");
RETRY(mscclppConnectionSetup(comm._handle),
"Failed to setup MSCCLPP connection");
},
nb::call_guard<nb::gil_scoped_release>(),
"Run connection setup for MSCCLPP.")

View File

@@ -10,6 +10,8 @@ logger = logging.getLogger(__file__)
from . import _py_mscclpp
_py_mscclpp._setup()
__all__ = (
"Comm",
"MscclppUniqueId",
@@ -46,7 +48,6 @@ MSCCLPP_LOG_LEVELS: set[str] = {
"TRACE",
}
def _setup_logging(level: str = "INFO"):
"""Setup log hooks for the C library."""
level = level.upper()

View File

@@ -74,9 +74,11 @@ class CommsTest(unittest.TestCase):
try:
f.result()
except subprocess.CalledProcessError as e:
errors.append(e.output)
errors.append((rank, e.output))
if errors:
raise AssertionError(
"\n\n".join(e.decode("utf-8", errors="ignore") for e in errors)
)
parts = []
for rank, content in errors:
parts.append(f"[rank {rank}]: " + content.decode('utf-8', errors='ignore'))
raise AssertionError("\n\n".join(parts))

View File

@@ -6,6 +6,7 @@ import hamcrest
import torch
import mscclpp
import time
@dataclass
@@ -74,20 +75,30 @@ def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
buf = torch.zeros(
[options.world_size], dtype=torch.int64, device="cuda"
).contiguous()
rank = options.rank
buf = torch.zeros([options.world_size], dtype=torch.int64)
buf[rank] = 42 + rank
buf = buf.cuda().contiguous()
tag = 0
remote_rank = (options.rank + 1) % options.world_size
if rank:
remote_rank = 0
else:
remote_rank = 1
comm.connect(
remote_rank,
tag,
buf.data_ptr(),
buf.element_size() * buf.numel(),
mscclpp._py_mscclpp.TransportType.P2P,
mscclpp.TransportType.P2P,
)
torch.cuda.synchronize()
# time.sleep(3)
comm.connection_setup()