mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
format
This commit is contained in:
@@ -128,18 +128,18 @@ struct _Comm {
|
||||
};
|
||||
|
||||
struct _P2PHandle {
|
||||
struct mscclppRegisteredMemoryP2P _rmP2P;
|
||||
struct mscclppIbMr _ibmr;
|
||||
struct mscclppRegisteredMemoryP2P _rmP2P;
|
||||
struct mscclppIbMr _ibmr;
|
||||
|
||||
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
|
||||
_P2PHandle() : _rmP2P({0}), _ibmr({0}) {}
|
||||
|
||||
_P2PHandle(const mscclppRegisteredMemoryP2P &p2p): _ibmr({0}) {
|
||||
_rmP2P = p2p;
|
||||
if (_rmP2P.IbMr != nullptr) {
|
||||
_ibmr = *_rmP2P.IbMr;
|
||||
_rmP2P.IbMr = &_ibmr;
|
||||
}
|
||||
_P2PHandle(const mscclppRegisteredMemoryP2P& p2p) : _ibmr({0}) {
|
||||
_rmP2P = p2p;
|
||||
if (_rmP2P.IbMr != nullptr) {
|
||||
_ibmr = *_rmP2P.IbMr;
|
||||
_rmP2P.IbMr = &_ibmr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
nb::callable _log_callback;
|
||||
@@ -261,7 +261,9 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
.def_ro("world_size", &_Comm::_world_size)
|
||||
.def(
|
||||
"register_buffer",
|
||||
[](_Comm& comm, uint64_t local_buff, uint64_t buff_size) -> std::vector<_P2PHandle> {
|
||||
[](_Comm& comm,
|
||||
uint64_t local_buff,
|
||||
uint64_t buff_size) -> std::vector<_P2PHandle> {
|
||||
comm.check_open();
|
||||
mscclppRegisteredMemory regMem;
|
||||
checkResult(
|
||||
@@ -273,8 +275,8 @@ NB_MODULE(_py_mscclpp, m) {
|
||||
"Registering buffer failed");
|
||||
|
||||
std::vector<_P2PHandle> handles;
|
||||
for (const auto &p2p : regMem.p2p) {
|
||||
handles.push_back(_P2PHandle(p2p));
|
||||
for (const auto& p2p : regMem.p2p) {
|
||||
handles.push_back(_P2PHandle(p2p));
|
||||
}
|
||||
return handles;
|
||||
},
|
||||
|
||||
@@ -71,11 +71,11 @@ class Comm:
|
||||
|
||||
@staticmethod
|
||||
def init_rank_from_address(
|
||||
address: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
address: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
*,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
"""Initialize a Comm from an address.
|
||||
|
||||
@@ -156,12 +156,12 @@ class Comm:
|
||||
return [pickle.loads(b) for b in self.all_gather_bytes(pickle.dumps(item))]
|
||||
|
||||
def connect(
|
||||
self,
|
||||
remote_rank: int,
|
||||
tag: int,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
transport: int,
|
||||
self,
|
||||
remote_rank: int,
|
||||
tag: int,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
transport: int,
|
||||
) -> None:
|
||||
self._comm.connect(
|
||||
remote_rank,
|
||||
@@ -181,12 +181,13 @@ class Comm:
|
||||
self._comm.stop_proxies()
|
||||
|
||||
def register_buffer(
|
||||
self,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
self,
|
||||
data_ptr,
|
||||
data_size: int,
|
||||
) -> list[_P2PHandle]:
|
||||
return [
|
||||
P2PHandle(self, h) for h in self._comm.register_buffer(
|
||||
P2PHandle(self, h)
|
||||
for h in self._comm.register_buffer(
|
||||
data_ptr,
|
||||
data_size,
|
||||
)
|
||||
@@ -197,8 +198,6 @@ class P2PHandle:
|
||||
_comm: Comm
|
||||
_handle: _P2PHandle
|
||||
|
||||
def __init__(self,
|
||||
comm: Comm,
|
||||
handle: _P2PHandle):
|
||||
def __init__(self, comm: Comm, handle: _P2PHandle):
|
||||
self._comm = comm
|
||||
self._handle = handle
|
||||
|
||||
Reference in New Issue
Block a user