This commit is contained in:
Crutcher Dunnavant
2023-04-07 19:12:05 -07:00
parent b25bcf5f93
commit 19962a8002
2 changed files with 31 additions and 30 deletions

View File

@@ -128,18 +128,18 @@ struct _Comm {
};
struct _P2PHandle {
struct mscclppRegisteredMemoryP2P _rmP2P;
struct mscclppIbMr _ibmr;
struct mscclppRegisteredMemoryP2P _rmP2P;
struct mscclppIbMr _ibmr;
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
_P2PHandle(const mscclppRegisteredMemoryP2P &p2p): _ibmr({0}) {
_rmP2P = p2p;
if (_rmP2P.IbMr != nullptr) {
_ibmr = *_rmP2P.IbMr;
_rmP2P.IbMr = &_ibmr;
}
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
_rmP2P = p2p;
if (_rmP2P.IbMr != nullptr) {
_ibmr = *_rmP2P.IbMr;
_rmP2P.IbMr = &_ibmr;
}
}
};
nb::callable _log_callback;
@@ -261,7 +261,9 @@ NB_MODULE(_py_mscclpp, m) {
.def_ro("world_size", &_Comm::_world_size)
.def(
"register_buffer",
[](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> {
[](_Comm& comm,
uint64_t local_buff,
uint64_t buff_size) -> std::vector<_P2PHandle> {
comm.check_open();
mscclppRegisteredMemory regMem;
checkResult(
@@ -273,8 +275,8 @@ NB_MODULE(_py_mscclpp, m) {
"Registering buffer failed");
std::vector<_P2PHandle> handles;
for (const auto &p2p : regMem.p2p) {
handles.push_back(_P2PHandle(p2p));
for (const auto& p2p : regMem.p2p) {
handles.push_back(_P2PHandle(p2p));
}
return handles;
},

View File

@@ -71,11 +71,11 @@ class Comm:
@staticmethod
def init_rank_from_address(
address: str,
rank: int,
world_size: int,
*,
port: Optional[int] = None,
address: str,
rank: int,
world_size: int,
*,
port: Optional[int] = None,
):
"""Initialize a Comm from an address.
@@ -156,12 +156,12 @@ class Comm:
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
def connect(
self,
remote_rank: int,
tag: int,
data_ptr,
data_size: int,
transport: int,
self,
remote_rank: int,
tag: int,
data_ptr,
data_size: int,
transport: int,
) -> None:
self._comm.connect(
remote_rank,
@@ -181,12 +181,13 @@ class Comm:
self._comm.stop_proxies()
def register_buffer(
self,
data_ptr,
data_size: int,
self,
data_ptr,
data_size: int,
) -> list[_P2PHandle]:
return [
P2PHandle(self, h) for h in self._comm.register_buffer(
P2PHandle(self, h)
for h in self._comm.register_buffer(
data_ptr,
data_size,
)
@@ -197,8 +198,6 @@ class P2PHandle:
_comm: Comm
_handle: _P2PHandle
def __init__(self,
comm: Comm,
handle: _P2PHandle):
def __init__(self, comm: Comm, handle: _P2PHandle):
self._comm = comm
self._handle = handle