mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
bug
This commit is contained in:
committed by
Crutcher Dunnavant
parent
7753c38eb1
commit
e65def8657
@@ -3,6 +3,8 @@
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -69,6 +71,15 @@ void checkResult(
|
||||
}
|
||||
}
|
||||
|
||||
// RETRY(C, ...): evaluate the call expression `C` again and again for as long
// as it yields mscclppInProgress, then pass the final result together with the
// remaining arguments (the failure message) to checkResult().
//
// The body is wrapped in do { ... } while (0) so that `RETRY(...);` expands to
// exactly one statement and is safe inside an unbraced if/else. The local
// variable has a macro-specific name so it cannot shadow an identifier that
// the expression `C` itself refers to (a plain `res` would).
//
// NOTE: the loop re-issues `C` immediately, with no sleep or yield in between.
#define RETRY(C, ...)                                 \
  do {                                                \
    mscclppResult_t retry_result_;                    \
    do {                                              \
      retry_result_ = (C);                            \
    } while (retry_result_ == mscclppInProgress);     \
    checkResult(retry_result_, __VA_ARGS__);          \
  } while (0)
|
||||
|
||||
// Maybe return the value, maybe throw an exception.
|
||||
template <typename Val, typename... Args>
|
||||
Val maybe(
|
||||
@@ -136,6 +147,11 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
mscclppSetLogHandler(mscclppDefaultLogHandler);
|
||||
});
|
||||
|
||||
// Python-visible `_setup()`: asks the CUDA runtime for the current device.
// The device index itself is not used afterwards.
// NOTE(review): presumably this exists to touch the CUDA runtime once at
// import time - confirm the intent with the author.
m.def("_setup", []() {
  int device = 0;
  const cudaError_t status = cudaGetDevice(&device);
  if (status != cudaSuccess) {
    // Best-effort: report the failure instead of silently dropping the error
    // code, but do not throw, so importing the module keeps working.
    std::fprintf(stderr, "mscclpp _setup: cudaGetDevice failed: %s\n",
                 cudaGetErrorString(status));
  }
});
|
||||
|
||||
nb::enum_<mscclppTransport_t>(m, "TransportType")
|
||||
.value("P2P", mscclppTransport_t::mscclppTransportP2P)
|
||||
.value("SHM", mscclppTransport_t::mscclppTransportSHM)
|
||||
@@ -228,7 +244,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size,
|
||||
mscclppTransport_t transport_type) -> void {
|
||||
checkResult(
|
||||
RETRY(
|
||||
mscclppConnect(
|
||||
self._handle,
|
||||
remote_rank,
|
||||
@@ -236,7 +252,7 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
reinterpret_cast<void*>(local_buff),
|
||||
buff_size,
|
||||
transport_type,
|
||||
0 // ibDev
|
||||
NULL // ibDev
|
||||
),
|
||||
"Connect failed");
|
||||
},
|
||||
@@ -249,12 +265,10 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Attach a local buffer to a remote connection.")
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](_Comm& comm) {
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
return maybe(
|
||||
mscclppConnectionSetup(comm._handle),
|
||||
true,
|
||||
"Failed to setup MSCCLPP connection");
|
||||
RETRY(mscclppConnectionSetup(comm._handle),
|
||||
"Failed to setup MSCCLPP connection");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Run connection setup for MSCCLPP.")
|
||||
|
||||
@@ -10,6 +10,8 @@ logger = logging.getLogger(__file__)
|
||||
|
||||
from . import _py_mscclpp
|
||||
|
||||
_py_mscclpp._setup()
|
||||
|
||||
__all__ = (
|
||||
"Comm",
|
||||
"MscclppUniqueId",
|
||||
@@ -46,7 +48,6 @@ MSCCLPP_LOG_LEVELS: set[str] = {
|
||||
"TRACE",
|
||||
}
|
||||
|
||||
|
||||
def _setup_logging(level: str = "INFO"):
|
||||
"""Setup log hooks for the C library."""
|
||||
level = level.upper()
|
||||
|
||||
@@ -74,9 +74,11 @@ class CommsTest(unittest.TestCase):
|
||||
try:
|
||||
f.result()
|
||||
except subprocess.CalledProcessError as e:
|
||||
errors.append(e.output)
|
||||
errors.append((rank, e.output))
|
||||
|
||||
if errors:
|
||||
raise AssertionError(
|
||||
"\n\n".join(e.decode("utf-8", errors="ignore") for e in errors)
|
||||
)
|
||||
parts = []
|
||||
for rank, content in errors:
|
||||
parts.append(f"[rank {rank}]: " + content.decode('utf-8', errors='ignore'))
|
||||
|
||||
raise AssertionError("\n\n".join(parts))
|
||||
|
||||
@@ -6,6 +6,7 @@ import hamcrest
|
||||
import torch
|
||||
|
||||
import mscclpp
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -74,20 +75,30 @@ def _test_allgather_pickle(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
|
||||
|
||||
def _test_allgather_torch(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
buf = torch.zeros(
|
||||
[options.world_size], dtype=torch.int64, device="cuda"
|
||||
).contiguous()
|
||||
rank = options.rank
|
||||
|
||||
buf = torch.zeros([options.world_size], dtype=torch.int64)
|
||||
buf[rank] = 42 + rank
|
||||
buf = buf.cuda().contiguous()
|
||||
|
||||
tag = 0
|
||||
remote_rank = (options.rank + 1) % options.world_size
|
||||
|
||||
if rank:
|
||||
remote_rank = 0
|
||||
else:
|
||||
remote_rank = 1
|
||||
|
||||
comm.connect(
|
||||
remote_rank,
|
||||
tag,
|
||||
buf.data_ptr(),
|
||||
buf.element_size() * buf.numel(),
|
||||
mscclpp._py_mscclpp.TransportType.P2P,
|
||||
mscclpp.TransportType.P2P,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# time.sleep(3)
|
||||
|
||||
comm.connection_setup()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user