mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Use multiple thread blocks; Add peer-parallel kernels
This commit is contained in:
@@ -149,6 +149,22 @@ class MscclppAlltoAllV:
|
||||
self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
|
||||
self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
|
||||
|
||||
# Cache for split sizes to avoid redundant bootstrap exchanges and GPU copies.
|
||||
# Key: (tuple(send_counts_bytes), tuple(recv_counts_bytes))
|
||||
self._cached_splits_key = None
|
||||
self._cached_input_size = 0
|
||||
self._cached_output_size = 0
|
||||
self._cached_total_output_elems = 0
|
||||
self._cached_dtype = None
|
||||
# Pre-built extras dict (GPU pointers don't change)
|
||||
self._extras = {
|
||||
"sendCounts": self._d_send_counts.data_ptr(),
|
||||
"sendDispls": self._d_send_displs.data_ptr(),
|
||||
"recvCounts": self._d_recv_counts.data_ptr(),
|
||||
"recvDispls": self._d_recv_displs.data_ptr(),
|
||||
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
|
||||
}
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._rank
|
||||
@@ -219,35 +235,32 @@ class MscclppAlltoAllV:
|
||||
send_displs_bytes = [d * elem_size for d in send_displs]
|
||||
recv_counts_bytes = [s * elem_size for s in output_split_sizes]
|
||||
recv_displs_bytes = [d * elem_size for d in recv_displs]
|
||||
|
||||
# Copy to GPU
|
||||
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
|
||||
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
|
||||
self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
|
||||
self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
|
||||
|
||||
# Exchange recv displacements with all peers so each rank knows where to
|
||||
# write in the remote output buffer. remoteRecvDispls[peer] = peer's
|
||||
# recvDispls[rank], i.e. the offset in peer's output where our data goes.
|
||||
remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
|
||||
self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
|
||||
|
||||
# Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
|
||||
splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
|
||||
if splits_key != self._cached_splits_key:
|
||||
# Copy counts/displacements to GPU
|
||||
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
|
||||
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
|
||||
self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
|
||||
self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
|
||||
|
||||
# Exchange recv displacements with peers via bootstrap
|
||||
remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
|
||||
self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
|
||||
|
||||
# Cache for subsequent calls
|
||||
self._cached_splits_key = splits_key
|
||||
self._cached_input_size = sum(send_counts_bytes)
|
||||
self._cached_output_size = sum(recv_counts_bytes)
|
||||
|
||||
# Get stream
|
||||
if stream is None:
|
||||
stream = torch.cuda.current_stream()
|
||||
cuda_stream = stream.cuda_stream
|
||||
|
||||
# Build extras dict with GPU pointers
|
||||
extras = {
|
||||
"sendCounts": self._d_send_counts.data_ptr(),
|
||||
"sendDispls": self._d_send_displs.data_ptr(),
|
||||
"recvCounts": self._d_recv_counts.data_ptr(),
|
||||
"recvDispls": self._d_recv_displs.data_ptr(),
|
||||
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
|
||||
}
|
||||
|
||||
input_size = sum(send_counts_bytes)
|
||||
output_size = sum(recv_counts_bytes)
|
||||
|
||||
input_size = self._cached_input_size
|
||||
output_size = self._cached_output_size
|
||||
|
||||
# Execute the optimized kernel
|
||||
result = self._algo.execute(
|
||||
@@ -262,7 +275,7 @@ class MscclppAlltoAllV:
|
||||
None, # executor (not needed for native algos)
|
||||
0, # nblocks (auto)
|
||||
0, # nthreads_per_block (auto)
|
||||
extras,
|
||||
self._extras,
|
||||
)
|
||||
|
||||
if result != 0:
|
||||
|
||||
Reference in New Issue
Block a user