mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
[python] switch to python setup.py build and wheels
This commit is contained in:
2
Makefile
2
Makefile
@@ -61,7 +61,7 @@ endif
|
||||
|
||||
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all
|
||||
# Use addprefix so that we can specify more than one path
|
||||
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt -lcuda
|
||||
NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt
|
||||
|
||||
ifeq ($(DEBUG), 0)
|
||||
NVCUFLAGS += -O3
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <mscclpp.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -70,14 +71,14 @@ void checkResult(
|
||||
}
|
||||
}
|
||||
|
||||
#define RETRY(C, ...) \
|
||||
{ \
|
||||
mscclppResult_t res; \
|
||||
do { \
|
||||
res = (C); \
|
||||
} while (res == mscclppInProgress); \
|
||||
checkResult(res, __VA_ARGS__); \
|
||||
}
|
||||
#define RETRY(C, ...) \
|
||||
{ \
|
||||
mscclppResult_t res; \
|
||||
do { \
|
||||
res = (C); \
|
||||
} while (res == mscclppInProgress); \
|
||||
checkResult(res, __VA_ARGS__); \
|
||||
}
|
||||
|
||||
// Maybe return the value, maybe throw an exception.
|
||||
template <typename Val, typename... Args>
|
||||
@@ -97,11 +98,7 @@ struct _Comm {
|
||||
|
||||
public:
|
||||
_Comm(int rank, int world_size, mscclppComm_t handle)
|
||||
: _rank(rank),
|
||||
_world_size(world_size),
|
||||
_handle(handle),
|
||||
_is_open(true),
|
||||
_proxies_running(false) {}
|
||||
: _rank(rank), _world_size(world_size), _handle(handle), _is_open(true), _proxies_running(false) {}
|
||||
|
||||
~_Comm() { close(); }
|
||||
|
||||
@@ -109,8 +106,8 @@ struct _Comm {
|
||||
void close() {
|
||||
if (_is_open) {
|
||||
if (_proxies_running) {
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
mscclppProxyStop(_handle);
|
||||
_proxies_running = false;
|
||||
}
|
||||
checkResult(mscclppCommDestroy(_handle), "Failed to close comm channel");
|
||||
_handle = NULL;
|
||||
@@ -127,29 +124,6 @@ struct _Comm {
|
||||
}
|
||||
};
|
||||
|
||||
struct _P2PHandle {
|
||||
struct mscclppRegisteredMemoryP2P _rm;
|
||||
struct mscclppIbMr _local_ibmr;
|
||||
|
||||
_P2PHandle() : _rm({0}), _local_ibmr({0}) {}
|
||||
|
||||
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _local_ibmr({0}) {
|
||||
_rm = p2p;
|
||||
if (_rm.IbMr != nullptr) {
|
||||
_local_ibmr = *_rm.IbMr;
|
||||
_rm.IbMr = &_local_ibmr;
|
||||
}
|
||||
}
|
||||
|
||||
mscclppTransport_t transport() const {
|
||||
if (_rm.IbMr != nullptr) {
|
||||
return mscclppTransport_t::mscclppTransportIB;
|
||||
} else {
|
||||
return mscclppTransport_t::mscclppTransportP2P;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
nb::callable _log_callback;
|
||||
|
||||
void _LogHandler(const char* msg) {
|
||||
@@ -164,8 +138,6 @@ static const std::string DOC_MscclppUniqueId =
|
||||
|
||||
static const std::string DOC__Comm = "MSCCLPP Communications Handle";
|
||||
|
||||
static const std::string DOC__P2PHandle = "MSCCLPP P2P MR Handle";
|
||||
|
||||
NB_MODULE(_py_mscclpp, m) {
|
||||
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
|
||||
|
||||
@@ -216,62 +188,6 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
return nb::bytes(id.internal, sizeof(id.internal));
|
||||
});
|
||||
|
||||
nb::class_<_P2PHandle>(m, "_P2PHandle")
|
||||
.def_ro_static("__doc__", &DOC__P2PHandle)
|
||||
.def(
|
||||
"transport",
|
||||
&_P2PHandle::transport,
|
||||
"Get the transport type of the handle")
|
||||
.def(
|
||||
"data_ptr",
|
||||
[](const _P2PHandle& self) -> uint64_t {
|
||||
if (self.transport() == mscclppTransport_t::mscclppTransportP2P) {
|
||||
return reinterpret_cast<uint64_t>(self._rm.remoteBuff);
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"IB transport does not have a local data ptr");
|
||||
},
|
||||
"Get the local data pointer, only for P2P handles");
|
||||
|
||||
nb::class_<mscclppRegisteredMemory>(m, "_RegisteredMemory")
|
||||
.def(
|
||||
"handles",
|
||||
[](const mscclppRegisteredMemory& self) -> std::vector<_P2PHandle> {
|
||||
std::vector<_P2PHandle> handles;
|
||||
for (const auto& p2p : self.p2p) {
|
||||
handles.push_back(_P2PHandle(p2p));
|
||||
}
|
||||
return handles;
|
||||
},
|
||||
"Get the P2P handle for this memory")
|
||||
.def(
|
||||
"write_all",
|
||||
[](const mscclppRegisteredMemory& self,
|
||||
const _Comm& comm,
|
||||
uint64_t src_data,
|
||||
size_t size,
|
||||
uint32_t src_offset = 0,
|
||||
uint32_t dst_offset = 0,
|
||||
int64_t stream = 0) -> void {
|
||||
checkResult(
|
||||
mscclppRegisteredBufferWrite(
|
||||
comm._handle,
|
||||
const_cast<mscclppRegisteredMemory*>(&self),
|
||||
reinterpret_cast<void*>(src_data),
|
||||
size,
|
||||
src_offset,
|
||||
dst_offset,
|
||||
stream),
|
||||
"Failed to write to registered memory");
|
||||
},
|
||||
"comm"_a,
|
||||
"src_data"_a,
|
||||
"size"_a,
|
||||
"src_offset"_a = 0,
|
||||
"dst_offset"_a = 0,
|
||||
"stream"_a = 0,
|
||||
"Write to all bound targets in the buffer");
|
||||
|
||||
nb::class_<_Comm>(m, "_Comm")
|
||||
.def_ro_static("__doc__", &DOC__Comm)
|
||||
.def_static(
|
||||
@@ -309,54 +225,35 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"id"_a,
|
||||
"rank"_a,
|
||||
"world_size"_a,
|
||||
"Initialize comms given UniqueID, rank, and world_size")
|
||||
"Initialize comms given u UniqueID, rank, and world_size")
|
||||
.def(
|
||||
"opened",
|
||||
[](_Comm& self) -> bool { return self._is_open; },
|
||||
[](_Comm& comm) -> bool { return comm._is_open; },
|
||||
"Is this comm object opened?")
|
||||
.def(
|
||||
"closed",
|
||||
[](_Comm& self) -> bool { return !self._is_open; },
|
||||
[](_Comm& comm) -> bool { return !comm._is_open; },
|
||||
"Is this comm object closed?")
|
||||
.def_ro("rank", &_Comm::_rank)
|
||||
.def_ro("world_size", &_Comm::_world_size)
|
||||
.def(
|
||||
"register_buffer",
|
||||
[](_Comm& self,
|
||||
uint64_t data_ptr,
|
||||
uint64_t size) -> mscclppRegisteredMemory {
|
||||
self.check_open();
|
||||
mscclppRegisteredMemory regMem;
|
||||
checkResult(
|
||||
mscclppRegisterBuffer(
|
||||
self._handle,
|
||||
reinterpret_cast<void*>(data_ptr),
|
||||
size,
|
||||
®Mem),
|
||||
"Registering buffer failed");
|
||||
return regMem;
|
||||
;
|
||||
},
|
||||
"data_ptr"_a,
|
||||
"size"_a,
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Register a buffer for P2P transfers.")
|
||||
.def(
|
||||
"connect",
|
||||
[](_Comm& self,
|
||||
[](_Comm& comm,
|
||||
int remote_rank,
|
||||
int tag,
|
||||
uint64_t data_ptr,
|
||||
uint64_t size,
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size,
|
||||
mscclppTransport_t transport_type) -> void {
|
||||
self.check_open();
|
||||
if (comm._proxies_running) {
|
||||
throw std::invalid_argument("Proxy Threads Already Running");
|
||||
}
|
||||
RETRY(
|
||||
mscclppConnect(
|
||||
self._handle,
|
||||
comm._handle,
|
||||
remote_rank,
|
||||
tag,
|
||||
reinterpret_cast<void*>(data_ptr),
|
||||
size,
|
||||
reinterpret_cast<void*>(local_buff),
|
||||
buff_size,
|
||||
transport_type,
|
||||
NULL // ibDev
|
||||
),
|
||||
@@ -364,56 +261,64 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
},
|
||||
"remote_rank"_a,
|
||||
"tag"_a,
|
||||
"data_ptr"_a,
|
||||
"size"_a,
|
||||
"local_buf"_a,
|
||||
"buff_size"_a,
|
||||
"transport_type"_a,
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Attach a local buffer to a remote connection.")
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](_Comm& self) -> void {
|
||||
self.check_open();
|
||||
RETRY(
|
||||
mscclppConnectionSetup(self._handle),
|
||||
"Failed to setup MSCCLPP connection");
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
RETRY(mscclppConnectionSetup(comm._handle),
|
||||
"Failed to setup MSCCLPP connection");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Run connection setup for MSCCLPP.")
|
||||
.def(
|
||||
"launch_proxies",
|
||||
[](_Comm& self) -> void {
|
||||
self.check_open();
|
||||
if (self._proxies_running) {
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
if (comm._proxies_running) {
|
||||
throw std::invalid_argument("Proxy Threads Already Running");
|
||||
}
|
||||
checkResult(
|
||||
mscclppProxyLaunch(self._handle),
|
||||
mscclppProxyLaunch(comm._handle),
|
||||
"Failed to launch MSCCLPP proxy");
|
||||
self._proxies_running = true;
|
||||
comm._proxies_running = true;
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Start the MSCCLPP proxy.")
|
||||
.def(
|
||||
"stop_proxies",
|
||||
[](_Comm& self) -> void {
|
||||
self.check_open();
|
||||
if (self._proxies_running) {
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
if (comm._proxies_running) {
|
||||
checkResult(
|
||||
mscclppProxyStop(self._handle),
|
||||
mscclppProxyStop(comm._handle),
|
||||
"Failed to stop MSCCLPP proxy");
|
||||
self._proxies_running = false;
|
||||
comm._proxies_running = false;
|
||||
}
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
"Start the MSCCLPP proxy.")
|
||||
.def("close", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("__del__", &_Comm::close, nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"connection_setup",
|
||||
[](_Comm& comm) -> void {
|
||||
comm.check_open();
|
||||
checkResult(
|
||||
mscclppConnectionSetup(comm._handle),
|
||||
"Connection Setup Failed");
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def(
|
||||
"bootstrap_all_gather_int",
|
||||
[](_Comm& self, int val) -> std::vector<int> {
|
||||
std::vector<int> buf(self._world_size);
|
||||
buf[self._rank] = val;
|
||||
mscclppBootstrapAllGather(self._handle, buf.data(), sizeof(int));
|
||||
[](_Comm& comm, int val) -> std::vector<int> {
|
||||
std::vector<int> buf(comm._world_size);
|
||||
buf[comm._rank] = val;
|
||||
mscclppBootstrapAllGather(comm._handle, buf.data(), sizeof(int));
|
||||
return buf;
|
||||
},
|
||||
nb::call_guard<nb::gil_scoped_release>(),
|
||||
@@ -421,13 +326,13 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"all-gather ints over the bootstrap connection.")
|
||||
.def(
|
||||
"all_gather_bytes",
|
||||
[](_Comm& self, nb::bytes& item) -> std::vector<nb::bytes> {
|
||||
[](_Comm& comm, nb::bytes& item) -> std::vector<nb::bytes> {
|
||||
// First, all-gather the sizes of all bytes.
|
||||
std::vector<size_t> sizes(self._world_size);
|
||||
sizes[self._rank] = item.size();
|
||||
std::vector<size_t> sizes(comm._world_size);
|
||||
sizes[comm._rank] = item.size();
|
||||
checkResult(
|
||||
mscclppBootstrapAllGather(
|
||||
self._handle, sizes.data(), sizeof(size_t)),
|
||||
comm._handle, sizes.data(), sizeof(size_t)),
|
||||
"bootstrapAllGather failed.");
|
||||
|
||||
// Next, find the largest message to send.
|
||||
@@ -435,21 +340,21 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
|
||||
// Allocate an all-gather buffer large enough for max * world_size.
|
||||
std::shared_ptr<char[]> data_buf(
|
||||
new char[max_size * self._world_size]);
|
||||
new char[max_size * comm._world_size]);
|
||||
|
||||
// Copy the local item into the buffer.
|
||||
std::memcpy(
|
||||
&data_buf[self._rank * max_size], item.c_str(), item.size());
|
||||
&data_buf[comm._rank * max_size], item.c_str(), item.size());
|
||||
|
||||
// all-gather the data buffer.
|
||||
checkResult(
|
||||
mscclppBootstrapAllGather(
|
||||
self._handle, data_buf.get(), max_size),
|
||||
comm._handle, data_buf.get(), max_size),
|
||||
"bootstrapAllGather failed.");
|
||||
|
||||
// Build a response vector.
|
||||
std::vector<nb::bytes> ret;
|
||||
for (int i = 0; i < self._world_size; ++i) {
|
||||
for (int i = 0; i < comm._world_size; ++i) {
|
||||
// Copy out the relevant range of each item.
|
||||
ret.push_back(nb::bytes(&data_buf[i * max_size], sizes[i]));
|
||||
}
|
||||
|
||||
@@ -10,9 +10,14 @@ logger = logging.getLogger(__file__)
|
||||
|
||||
from . import _py_mscclpp
|
||||
|
||||
__all__ = (
|
||||
"Comm",
|
||||
"MscclppUniqueId",
|
||||
"MSCCLPP_UNIQUE_ID_BYTES",
|
||||
"TransportType",
|
||||
)
|
||||
|
||||
_Comm = _py_mscclpp._Comm
|
||||
_RegisteredMemory = _py_mscclpp._RegisteredMemory
|
||||
_P2PHandle = _py_mscclpp._P2PHandle
|
||||
TransportType = _py_mscclpp.TransportType
|
||||
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
@@ -41,7 +46,6 @@ MSCCLPP_LOG_LEVELS: set[str] = {
|
||||
"TRACE",
|
||||
}
|
||||
|
||||
|
||||
def _setup_logging(level: str = "INFO"):
|
||||
"""Setup log hooks for the C library."""
|
||||
level = level.upper()
|
||||
@@ -61,7 +65,7 @@ _setup_logging()
|
||||
class Comm:
|
||||
"""Comm object; represents a mscclpp connection."""
|
||||
|
||||
_c_comm: _Comm
|
||||
_comm: _Comm
|
||||
|
||||
@staticmethod
|
||||
def init_rank_from_address(
|
||||
@@ -91,16 +95,16 @@ class Comm:
|
||||
|
||||
def __init__(self, *, _comm: _Comm):
|
||||
"""Construct a Comm object wrapping an internal _Comm handle."""
|
||||
self._c_comm = _comm
|
||||
self._comm = _comm
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._c_comm:
|
||||
self._c_comm.close()
|
||||
self._c_comm = None
|
||||
if self._comm:
|
||||
self._comm.close()
|
||||
self._comm = None
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
@@ -108,7 +112,7 @@ class Comm:
|
||||
|
||||
Assumes the Comm is open.
|
||||
"""
|
||||
return self._c_comm.rank
|
||||
return self._comm.rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
@@ -116,11 +120,11 @@ class Comm:
|
||||
|
||||
Assumes the Comm is open.
|
||||
"""
|
||||
return self._c_comm.world_size
|
||||
return self._comm.world_size
|
||||
|
||||
def bootstrap_all_gather_int(self, val: int) -> list[int]:
|
||||
"""AllGather an int value through the bootstrap interface."""
|
||||
return self._c_comm.bootstrap_all_gather_int(val)
|
||||
return self._comm.bootstrap_all_gather_int(val)
|
||||
|
||||
def all_gather_bytes(self, item: bytes) -> list[bytes]:
|
||||
"""AllGather bytes (of different sizes) through the bootstrap interface.
|
||||
@@ -128,7 +132,7 @@ class Comm:
|
||||
:param item: the bytes object for this rank.
|
||||
:return: a list of bytes objects; the ret[rank] object will be a new copy.
|
||||
"""
|
||||
return self._c_comm.all_gather_bytes(item)
|
||||
return self._comm.all_gather_bytes(item)
|
||||
|
||||
def all_gather_json(self, item: Any) -> list[Any]:
|
||||
"""AllGather JSON objects through the bootstrap interface.
|
||||
@@ -157,7 +161,7 @@ class Comm:
|
||||
data_size: int,
|
||||
transport: int,
|
||||
) -> None:
|
||||
self._c_comm.connect(
|
||||
self._comm.connect(
|
||||
remote_rank,
|
||||
tag,
|
||||
data_ptr,
|
||||
@@ -165,108 +169,11 @@ class Comm:
|
||||
transport,
|
||||
)
|
||||
|
||||
@classmethod
def connect_rank_from_address(
    cls,
    address: str,
    rank: int,
    world_size: int,
    data_ptr: int,
    data_size: int,
    transport=TransportType.P2P,
):
    """Create a Comm for ``rank`` and connect it to every other rank.

    The communicator is built with ``init_rank_from_address`` and then
    ``connect`` is called once per peer (every rank in ``range(world_size)``
    except ``rank`` itself), always with tag 0 and the same buffer/transport.

    :return: the newly created communicator.
    """
    new_comm = cls.init_rank_from_address(
        address=address, rank=rank, world_size=world_size
    )

    # Peers are visited in ascending rank order, skipping ourselves.
    peer_ranks = (r for r in range(world_size) if r != rank)
    for peer in peer_ranks:
        new_comm.connect(
            remote_rank=peer,
            tag=0,
            data_ptr=data_ptr,
            data_size=data_size,
            transport=transport,
        )

    return new_comm
|
||||
|
||||
def connection_setup(self) -> None:
|
||||
self._c_comm.connection_setup()
|
||||
self._comm.connection_setup()
|
||||
|
||||
def launch_proxies(self) -> None:
|
||||
self._c_comm.launch_proxies()
|
||||
self._comm.launch_proxies()
|
||||
|
||||
def stop_proxies(self) -> None:
|
||||
self._c_comm.stop_proxies()
|
||||
|
||||
def register_buffer(
    self,
    data_ptr: int,
    size: int,
) -> "RegisteredMemory":
    """Register a buffer with the underlying comm and wrap the result.

    :param data_ptr: address of the buffer, passed through as an integer.
    :param size: size of the buffer, passed through unchanged.
    :return: a RegisteredMemory bound to this Comm.
    """
    native_rm = self._c_comm.register_buffer(data_ptr=data_ptr, size=size)
    return RegisteredMemory(comm=self, rm=native_rm)
|
||||
|
||||
|
||||
class RegisteredMemory:
    """Python-side wrapper pairing a native registered memory with its Comm."""

    # Comm the memory was registered on.
    _comm: Comm
    # Native _RegisteredMemory object returned by the binding.
    _c_rm: _RegisteredMemory

    def __init__(
        self,
        *,
        comm: Comm,
        rm: _RegisteredMemory,
    ):
        """Store the owning Comm and the native registered-memory object."""
        self._c_rm = rm
        self._comm = comm

    def handles(self) -> list["P2PHandle"]:
        """Return one P2PHandle wrapper per native handle of this memory."""
        wrapped = []
        for native_handle in self._c_rm.handles():
            wrapped.append(P2PHandle(comm=self._comm, handle=native_handle))
        return wrapped

    def _write(
        self,
        src_ptr: int,
        size: int,
        *,
        src_offset: int = 0,
        dst_offset: int = 0,
        stream: int = 0,
    ) -> None:
        """Forward a write request to the native ``write_all`` binding.

        ``src_ptr`` is passed as ``src_data``; all other arguments keep
        their names. The native comm handle is taken from the owning Comm.
        """
        native_comm = self._comm._c_comm
        self._c_rm.write_all(
            comm=native_comm,
            src_data=src_ptr,
            size=size,
            src_offset=src_offset,
            dst_offset=dst_offset,
            stream=stream,
        )
|
||||
|
||||
|
||||
class P2PHandle:
    """Python-side wrapper around a native _P2PHandle and its owning Comm."""

    # Comm this handle was obtained through.
    _comm: Comm
    # Native handle object exposed by the binding.
    _c_handle: _P2PHandle

    def __init__(self, *, comm: Comm, handle: _P2PHandle):
        """Store the owning Comm and the native handle."""
        self._c_handle = handle
        self._comm = comm

    def transport(self) -> TransportType:
        """Return the transport type reported by the native handle."""
        native = self._c_handle
        return native.transport()

    def data_ptr(self) -> int:
        """Return the data pointer reported by the native handle."""
        native = self._c_handle
        return native.data_ptr()
|
||||
self._comm.stop_proxies()
|
||||
|
||||
@@ -79,8 +79,6 @@ class CommsTest(unittest.TestCase):
|
||||
if errors:
|
||||
parts = []
|
||||
for rank, content in errors:
|
||||
parts.append(
|
||||
f"[rank {rank}]: " + content.decode("utf-8", errors="ignore")
|
||||
)
|
||||
parts.append(f"[rank {rank}]: " + content.decode('utf-8', errors='ignore'))
|
||||
|
||||
raise AssertionError("\n\n".join(parts))
|
||||
|
||||
@@ -73,7 +73,7 @@ def _test_bootstrap_allgather_pickle(options: argparse.Namespace, comm: mscclpp.
|
||||
comm.connection_setup()
|
||||
|
||||
|
||||
def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
def _test_p2p_connect(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
rank = options.rank
|
||||
|
||||
buf = torch.zeros([options.world_size], dtype=torch.int64)
|
||||
@@ -95,23 +95,6 @@ def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
mscclpp.TransportType.P2P,
|
||||
)
|
||||
|
||||
rm = comm.register_buffer(buf.data_ptr(), buf.element_size() * buf.numel())
|
||||
handles = rm.handles()
|
||||
hamcrest.assert_that(
|
||||
handles,
|
||||
hamcrest.has_length(options.world_size - 1),
|
||||
)
|
||||
for handle in handles:
|
||||
hamcrest.assert_that(
|
||||
handle.transport(),
|
||||
hamcrest.equal_to(mscclpp.TransportType.P2P),
|
||||
)
|
||||
# assuming P2P ...
|
||||
hamcrest.assert_that(
|
||||
handle.data_ptr(),
|
||||
hamcrest.greater_than(0),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
comm.connection_setup()
|
||||
@@ -120,6 +103,7 @@ def _test_rm(options: argparse.Namespace, comm: mscclpp.Comm):
|
||||
comm.stop_proxies()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rank", type=int, required=True)
|
||||
@@ -147,7 +131,7 @@ def main():
|
||||
_test_bootstrap_allgather_bytes(options, comm)
|
||||
_test_bootstrap_allgather_json(options, comm)
|
||||
_test_bootstrap_allgather_pickle(options, comm)
|
||||
_test_rm(options, comm)
|
||||
_test_p2p_connect(options, comm)
|
||||
finally:
|
||||
comm.close()
|
||||
|
||||
|
||||
@@ -20,8 +20,6 @@
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define CUDA_CHECK(cmd) (cmd)
|
||||
|
||||
#define CUDACHECKGOTO(cmd, res, label) \
|
||||
do { \
|
||||
cudaError_t err = cmd; \
|
||||
|
||||
@@ -12,6 +12,21 @@
|
||||
#define MSCCLPP_IB_MAX_SENDS 64
|
||||
#define MSCCLPP_IB_MAX_DEVS 8
|
||||
|
||||
// MR info to be shared with the remote peer
|
||||
struct mscclppIbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
// IB memory region
|
||||
struct mscclppIbMr
|
||||
{
|
||||
struct ibv_mr* mr;
|
||||
void* buff;
|
||||
struct mscclppIbMrInfo info;
|
||||
};
|
||||
|
||||
// QP info to be shared with the remote peer
|
||||
struct mscclppIbQpInfo
|
||||
{
|
||||
|
||||
@@ -13,11 +13,9 @@
|
||||
|
||||
#include <mscclppfifo.h>
|
||||
#include <time.h>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/***************************************************************************************************************
|
||||
@@ -176,32 +174,6 @@ typedef struct
|
||||
char internal[MSCCLPP_UNIQUE_ID_BYTES];
|
||||
} mscclppUniqueId;
|
||||
|
||||
// MR info to be shared with the remote peer
|
||||
struct mscclppIbMrInfo
|
||||
{
|
||||
uint64_t addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
|
||||
// IB memory region
|
||||
struct mscclppIbMr
|
||||
{
|
||||
struct ibv_mr* mr;
|
||||
void* buff;
|
||||
struct mscclppIbMrInfo info;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemoryP2P
|
||||
{
|
||||
void* remoteBuff;
|
||||
mscclppIbMr* IbMr;
|
||||
};
|
||||
|
||||
struct mscclppRegisteredMemory
|
||||
{
|
||||
std::vector<mscclppRegisteredMemoryP2P> p2p;
|
||||
};
|
||||
|
||||
/* Error type */
|
||||
typedef enum
|
||||
{
|
||||
@@ -405,33 +377,6 @@ void mscclppDefaultLogHandler(const char* msg);
|
||||
*/
|
||||
mscclppResult_t mscclppSetLogHandler(mscclppLogHandler_t handler);
|
||||
|
||||
/* Register a buffer for RDMA.
|
||||
*
|
||||
* Outputs:
|
||||
* regMem: the registered memory
|
||||
*
|
||||
* Inputs:
|
||||
* comm: the communicator
|
||||
* local_memory: the local buffer to be registered
|
||||
* size: the size of the buffer
|
||||
*/
|
||||
mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
|
||||
mscclppRegisteredMemory* regMem);
|
||||
|
||||
/* Write to a registered buffer.
|
||||
*
|
||||
* Inputs:
|
||||
* comm: the communicator
|
||||
* regMem: the registered memory
|
||||
* srcBuff: the source buffer
|
||||
* size: the size of the buffer
|
||||
* srcOffset: the offset of the source buffer
|
||||
* dstOffset: the offset of the destination buffer
|
||||
* stream: the CUDA stream
|
||||
*/
|
||||
mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff,
|
||||
size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif
|
||||
|
||||
91
src/init.cc
91
src/init.cc
@@ -11,8 +11,6 @@
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
static uint64_t hashUniqueId(mscclppUniqueId const& id)
|
||||
{
|
||||
char const* bytes = (char const*)&id;
|
||||
@@ -562,95 +560,6 @@ mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct bufferInfo
|
||||
{
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
int64_t handleBuffOffset;
|
||||
mscclppIbMrInfo infoBuffMr;
|
||||
};
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppRegisterBuffer, mscclppComm_t comm, void* local_memory, size_t size,
|
||||
mscclppRegisteredMemory* regMem);
|
||||
/* Register `local_memory` with every connection of `comm` and exchange the
 * resulting access information with the peers over the bootstrap network.
 *
 * Phase 1 (send), per connection:
 *   - P2P: look up the base of the allocation containing `local_memory`,
 *          export a CUDA IPC handle for that base and record the offset of
 *          `local_memory` inside it.
 *   - IB:  register a memory region with the connection's IB context and
 *          record its MR info.
 *   The filled bufferInfo is then sent to the connection's remote rank.
 *
 * Phase 2 (recv), per connection:
 *   - P2P: open the peer's IPC handle and add the peer's offset to obtain the
 *          remote buffer pointer.
 *   - IB:  store the MR registered for this connection in phase 1.
 *   One entry per connection is appended to `regMem->p2p`, in connection order.
 *
 * Returns mscclppSuccess, or whatever error the CUDACHECK / MSCCLPPCHECK
 * macros propagate on failure.
 */
mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size,
                                      mscclppRegisteredMemory* regMem)
{
  // One slot per connection, so slot i always belongs to comm->conns[i].
  // (Previously only IB connections appended an entry, yet the vector was read
  // with the connection index, which is wrong or out of bounds as soon as a
  // non-IB connection precedes an IB one.)
  std::vector<struct mscclppIbMr*> ibMrs(comm->nConns, NULL);

  for (int i = 0; i < comm->nConns; ++i) {
    struct mscclppConn* conn = &comm->conns[i];
    // Zero-initialised: only the transport-specific part is filled in below,
    // but the whole struct is sent to the peer.
    struct bufferInfo bInfo = {};

    // TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
    if (conn->transport == mscclppTransportP2P) {
      CUdeviceptr base;
      // NOTE(review): CUDA_CHECK looks like a pass-through in this tree, so a
      // failing cuMemGetAddressRange would go unnoticed and leave `base`
      // unset -- confirm and switch to a checking macro if so.
      CUDA_CHECK(cuMemGetAddressRange(&base, NULL, (CUdeviceptr)local_memory));
      bInfo.handleBuffOffset = (int64_t)local_memory - (int64_t)base;
      CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, (void*)base));
    } else if (conn->transport == mscclppTransportIB) {
      MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibMrs[i]));
      bInfo.infoBuffMr = ibMrs[i]->info;
    }

    MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));
  }

  // Recv info from peers
  for (int i = 0; i < comm->nConns; ++i) {
    struct mscclppConn* conn = &comm->conns[i];
    struct bufferInfo bInfo = {};

    mscclppRegisteredMemoryP2P p2p;
    p2p.IbMr = NULL;
    p2p.remoteBuff = NULL;
    MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo)));

    // TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
    if (conn->transport == mscclppTransportP2P) {
      CUDACHECK(cudaIpcOpenMemHandle((void**)&p2p.remoteBuff, bInfo.handleBuff, cudaIpcMemLazyEnablePeerAccess));
      p2p.remoteBuff = (void*)((int64_t)p2p.remoteBuff + bInfo.handleBuffOffset);
    } else if (conn->transport == mscclppTransportIB) {
      // NOTE(review): the peer's MR info received in bInfo.infoBuffMr is not
      // stored anywhere; only the local MR is kept -- confirm this is intended.
      p2p.IbMr = ibMrs[i];
    }
    regMem->p2p.push_back(p2p);
  }
  return mscclppSuccess;
}
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppRegisteredBufferWrite, mscclppComm_t comm, mscclppRegisteredMemory* regMem,
|
||||
void* srcBuff, size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream);
|
||||
/* Write `size` bytes from `srcBuff + srcOffset` to every connection of `comm`.
 *
 *   - P2P connections: enqueue a device-to-device cudaMemcpyAsync on `stream`
 *     into `regMem->p2p[i].remoteBuff + dstOffset`, where i is the connection
 *     index.
 *   - all other connections: stage and post an unsignaled IB send on the
 *     connection's queue pair.
 *
 * Returns mscclppSuccess, or the error propagated by CUDACHECK if the copy
 * cannot be enqueued. A failing IB postSend is only logged.
 */
mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, mscclppRegisteredMemory* regMem, void* srcBuff,
                                             size_t size, uint32_t srcOffset, uint32_t dstOffset, int64_t stream)
{
  // TODO: transport should be an argument too so user can decide which transport to use
  for (int i = 0; i < comm->nConns; ++i) {
    struct mscclppConn* conn = &comm->conns[i];
    // TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB
    if (conn->transport == mscclppTransportP2P) {
      // Offsets are in bytes; go through char* because arithmetic on void* is
      // a compiler extension rather than standard C++.
      char* dstBuff = static_cast<char*>(regMem->p2p[i].remoteBuff) + dstOffset;
      char* src = static_cast<char*>(srcBuff) + srcOffset;
      CUDACHECK(cudaMemcpyAsync(dstBuff, src, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
      // INFO(MSCCLPP_INIT, "data memcpyAsync %p -> %p, size %zu", src, dstBuff, size);
    } else {
      // NOTE(review): this path uses the connection's own ibBuffMr /
      // ibBuffMrInfo rather than anything from `regMem`, and narrows `size`
      // to 32 bits -- confirm both are intended.
      conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)size,
                            /*wrId=*/0, /*srcOffset=*/srcOffset,
                            /*dstOffset=*/dstOffset,
                            /*signaled=*/false);
      const int ret = conn->ibQp->postSend();
      if (ret != 0) {
        // Return value is errno.
        // NOTE(review): the failure is logged but the function still reports
        // success to the caller -- consider returning an error code here.
        WARN("data postSend failed: errno %d", ret);
      }
      // ??
      // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_ENTRY, (uint32_t)trigger.fields.dataSize,
      // trigger.fields.connId);
    }
  }
  return mscclppSuccess;
}
|
||||
|
||||
// TODO: destroy registered buffer
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppProxyLaunch, mscclppComm_t comm);
|
||||
mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user