Use multiple thread blocks; Add peer-parallel kernels

This commit is contained in:
Qinghua Zhou
2026-02-24 04:05:01 +00:00
parent 21e3f1ebb3
commit f803eff8b9
4 changed files with 281 additions and 44 deletions

View File

@@ -149,6 +149,22 @@ class MscclppAlltoAllV:
self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
# Cache for split sizes to avoid redundant bootstrap exchanges and GPU copies.
# Key: (tuple(send_counts_bytes), tuple(recv_counts_bytes))
self._cached_splits_key = None
self._cached_input_size = 0
self._cached_output_size = 0
self._cached_total_output_elems = 0
self._cached_dtype = None
# Pre-built extras dict (GPU pointers don't change)
self._extras = {
"sendCounts": self._d_send_counts.data_ptr(),
"sendDispls": self._d_send_displs.data_ptr(),
"recvCounts": self._d_recv_counts.data_ptr(),
"recvDispls": self._d_recv_displs.data_ptr(),
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
}
@property
def rank(self) -> int:
return self._rank
@@ -219,35 +235,32 @@ class MscclppAlltoAllV:
send_displs_bytes = [d * elem_size for d in send_displs]
recv_counts_bytes = [s * elem_size for s in output_split_sizes]
recv_displs_bytes = [d * elem_size for d in recv_displs]
# Copy to GPU
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
# Exchange recv displacements with all peers so each rank knows where to
# write in the remote output buffer. remoteRecvDispls[peer] = peer's
# recvDispls[rank], i.e. the offset in peer's output where our data goes.
remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
# Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
if splits_key != self._cached_splits_key:
# Copy counts/displacements to GPU
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
# Exchange recv displacements with peers via bootstrap
remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
# Cache for subsequent calls
self._cached_splits_key = splits_key
self._cached_input_size = sum(send_counts_bytes)
self._cached_output_size = sum(recv_counts_bytes)
# Get stream
if stream is None:
stream = torch.cuda.current_stream()
cuda_stream = stream.cuda_stream
# Build extras dict with GPU pointers
extras = {
"sendCounts": self._d_send_counts.data_ptr(),
"sendDispls": self._d_send_displs.data_ptr(),
"recvCounts": self._d_recv_counts.data_ptr(),
"recvDispls": self._d_recv_displs.data_ptr(),
"remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
}
input_size = sum(send_counts_bytes)
output_size = sum(recv_counts_bytes)
input_size = self._cached_input_size
output_size = self._cached_output_size
# Execute the optimized kernel
result = self._algo.execute(
@@ -262,7 +275,7 @@ class MscclppAlltoAllV:
None, # executor (not needed for native algos)
0, # nblocks (auto)
0, # nthreads_per_block (auto)
extras,
self._extras,
)
if result != 0: