Mirror of https://github.com/microsoft/mscclpp.git
Get correct remote receive displacements for peers
@@ -147,6 +147,7 @@ class MscclppAlltoAllV:
         self._d_send_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
         self._d_recv_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
         self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
+        self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')

     @property
     def rank(self) -> int:
@@ -225,6 +226,12 @@ class MscclppAlltoAllV:
         self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
         self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))

+        # Exchange recv displacements with all peers so each rank knows where to
+        # write in the remote output buffer. remoteRecvDispls[peer] = peer's
+        # recvDispls[rank], i.e. the offset in peer's output where our data goes.
+        remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
+        self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
+
         # Get stream
         if stream is None:
             stream = torch.cuda.current_stream()
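To make the indexing concrete, here is a toy three-rank illustration of the relation described in the comment above. The offsets are made up; in the real code each rank's recvDispls row comes from the cumulative sums of its recv counts.

# Toy illustration (made-up numbers): recv_displs[i] is rank i's recvDispls row,
# i.e. the byte offset in rank i's output buffer where each source rank's data lands.
recv_displs = [
    [0, 16, 48],   # rank 0's output layout
    [0, 32, 40],   # rank 1's output layout
    [0,  8, 24],   # rank 2's output layout
]
rank = 1
# remoteRecvDispls[peer] = peer's recvDispls[rank]: where rank 1's data goes on each peer.
remote_recv_displs = [recv_displs[peer][rank] for peer in range(3)]
print(remote_recv_displs)  # -> [16, 32, 8]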
@@ -236,6 +243,7 @@ class MscclppAlltoAllV:
             "sendDispls": self._d_send_displs.data_ptr(),
             "recvCounts": self._d_recv_counts.data_ptr(),
             "recvDispls": self._d_recv_displs.data_ptr(),
+            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
         }

         input_size = sum(send_counts_bytes)
@@ -262,6 +270,57 @@ class MscclppAlltoAllV:

         return output

+    def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
+        """
+        Exchange recv displacement arrays between all ranks via bootstrap send/recv.
+
+        Each rank needs to know where to write in each peer's output buffer.
+        remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in
+        peer's output buffer where data from this rank should be placed.
+
+        Uses a deadlock-free pattern: lower-ranked peer sends first, then receives.
+
+        Args:
+            recv_displs_bytes: This rank's recv displacement array (in bytes)
+
+        Returns:
+            List of remote recv displacements (one per rank, in bytes).
+            remoteRecvDispls[rank] == recv_displs_bytes[rank] (self, unused by kernel)
+        """
+        rank = self._rank
+        world_size = self._world_size
+        bootstrap = self._comm.bootstrap()
+
+        # Use CPU int64 tensors as send/recv buffers (data_ptr() gives uintptr_t)
+        my_displs = torch.tensor(recv_displs_bytes, dtype=torch.int64)  # CPU tensor
+        data_size = world_size * 8  # world_size int64 values = world_size * 8 bytes
+
+        # Collect all ranks' recv_displs via pairwise send/recv
+        all_recv_displs = [None] * world_size
+        all_recv_displs[rank] = list(recv_displs_bytes)
+
+        for peer in range(world_size):
+            if peer == rank:
+                continue
+            peer_buf = torch.zeros(world_size, dtype=torch.int64)  # CPU tensor
+            if rank < peer:
+                # Lower rank sends first to avoid deadlock
+                bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
+                bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
+            else:
+                # Higher rank receives first
+                bootstrap.recv(peer_buf.data_ptr(), data_size, peer, 0)
+                bootstrap.send(my_displs.data_ptr(), data_size, peer, 0)
+            all_recv_displs[peer] = peer_buf.tolist()
+
+        # Build remoteRecvDispls: for each peer, what offset in peer's output
+        # buffer should this rank's data go to?
+        remote_recv_displs = []
+        for peer in range(world_size):
+            remote_recv_displs.append(int(all_recv_displs[peer][rank]))
+
+        return remote_recv_displs
+
     def __del__(self):
         """Cleanup resources."""
         # Let CUDA handle tensor cleanup automatically
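For comparison, the same exchange amounts to an all-gather of each rank's recvDispls row followed by selecting column `rank`. Below is a minimal sketch of that alternative, assuming an initialized torch.distributed process group (with a gloo backend for the CPU tensors); this is not what the commit does — the pairwise bootstrap path keeps the class independent of torch.distributed, and its lower-rank-sends-first ordering pairs every send with a matching recv so no two ranks block on each other.

# Hypothetical equivalent using torch.distributed (assumes dist.init_process_group
# was called with a backend that supports CPU tensors, e.g. gloo).
import torch
import torch.distributed as dist

def exchange_recv_displs_via_allgather(recv_displs_bytes: list) -> list:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mine = torch.tensor(recv_displs_bytes, dtype=torch.int64)
    gathered = [torch.empty(world_size, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered, mine)  # gathered[i] is rank i's recvDispls row
    # remoteRecvDispls[peer] = peer's recvDispls[rank]
    return [int(gathered[peer][rank]) for peer in range(world_size)]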
@@ -109,14 +109,25 @@ def main():
     if rank == 0:
         print("\n[Test 2] Variable-size all-to-all (MoE-like)")

-    # Simulate MoE token distribution: rank 0 sends more to rank 0, etc.
-    input_split_sizes = [(i + 1) * 512 for i in range(world_size)]
-    output_split_sizes = [512 * (rank + 1)] * world_size
+    # Simulate MoE token distribution with imbalanced routing.
+    # Build a full send matrix so each rank has different per-peer sizes.
+    # send_matrix[i][j] = number of elements rank i sends to rank j.
+    # For consistency: rank i's output_split[j] = send_matrix[j][i].
+    import random
+    random.seed(42)
+    send_matrix = []
+    for i in range(world_size):
+        row = [random.randint(128, 2048) for _ in range(world_size)]
+        send_matrix.append(row)
+
+    input_split_sizes = send_matrix[rank]  # what this rank sends to each peer
+    output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]  # what this rank receives from each peer

     total_input = sum(input_split_sizes)
     total_output = sum(output_split_sizes)

-    input_tensor = torch.randn(total_input, dtype=torch.float32, device='cuda')
+    # Fill input with rank-specific pattern for verification
+    input_tensor = torch.arange(total_input, dtype=torch.float32, device='cuda') + rank * 100000
     output_tensor = torch.empty(total_output, dtype=torch.float32, device='cuda')

     output = alltoallv.all_to_all_single(
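The row/column relation above is easy to sanity-check without GPUs. A toy three-rank example with made-up sizes, showing that each rank's sent and received totals generally differ (the "v" in alltoallv), while the diagonal entry is shared by both views:

# send_matrix[i][j]: elements rank i sends to rank j (made-up sizes)
send_matrix = [
    [2, 5, 1],
    [4, 3, 6],
    [7, 0, 2],
]
rank = 1
input_splits = send_matrix[rank]                    # row 1 -> [4, 3, 6]
output_splits = [row[rank] for row in send_matrix]  # column 1 -> [5, 3, 0]
# The diagonal entry appears in both views: rank 1 "sends itself" 3 elements.
assert input_splits[rank] == output_splits[rank] == 3
print(sum(input_splits), sum(output_splits))        # 13 elements out, 8 in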
@@ -127,43 +138,67 @@ def main():
     )

     torch.cuda.synchronize()
+
+    # Verify: the local-to-local segment should match exactly
+    local_send_offset = sum(input_split_sizes[:rank])
+    local_recv_offset = sum(output_split_sizes[:rank])
+    local_size = input_split_sizes[rank]  # == output_split_sizes[rank]
+    expected_local = input_tensor[local_send_offset:local_send_offset + local_size]
+    actual_local = output_tensor[local_recv_offset:local_recv_offset + local_size]
+    local_ok = torch.allclose(expected_local, actual_local)
+
     if rank == 0:
-        print(f"  Input splits: {input_split_sizes}")
-        print(f"  Output splits: {output_split_sizes}")
+        print(f"  Send matrix row (rank 0 sends): {input_split_sizes}")
+        print(f"  Recv sizes (rank 0 receives): {output_split_sizes}")
         print(f"  Input total: {total_input}, Output total: {total_output}")
-        print(f"  PASS")
+        print(f"  Local copy verified: {local_ok}")
+        print(f"  {'PASS' if local_ok else 'FAIL'}")

-    # Test 3: Performance benchmark
+    # Test 3: Performance benchmark across message sizes (1KB to 128MB)
     if rank == 0:
-        print("\n[Test 3] Performance benchmark (1MB per rank)")
-
-    msg_size = 1024 * 1024  # 1MB per message
-    input_size = msg_size * world_size
-
-    input_tensor = torch.randn(input_size // 4, dtype=torch.float32, device='cuda')  # 4 bytes per float
-    output_tensor = torch.empty_like(input_tensor)
-
-    # Warmup
-    for _ in range(5):
-        output = alltoallv.all_to_all_single(input_tensor, output=output_tensor)
-    torch.cuda.synchronize()
-
-    # Benchmark
-    n_iters = 20
-    start = time.perf_counter()
-    for _ in range(n_iters):
-        output = alltoallv.all_to_all_single(input_tensor, output=output_tensor)
-    torch.cuda.synchronize()
-    elapsed = time.perf_counter() - start
-
-    # Calculate bandwidth
-    total_bytes = 2 * input_size * n_iters  # read + write
-    bandwidth_gbps = total_bytes / elapsed / 1e9
-
-    if rank == 0:
-        print(f"  {n_iters} iterations in {elapsed*1000:.2f} ms")
-        print(f"  Bandwidth: {bandwidth_gbps:.2f} GB/s")
-        print(f"  Per-iteration: {elapsed/n_iters*1000:.3f} ms")
+        print("\n[Test 3] Performance benchmark (1KB to 128MB per rank)")
+        print(f"  {'Msg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'BW (GB/s)':>10s}")
+        print(f"  {'-'*10} {'-'*5} {'-'*10} {'-'*10} {'-'*10}")
+
+    # Message sizes: 1KB, 4KB, 16KB, 64KB, 256KB, 1MB, 4MB, 16MB, 64MB, 128MB
+    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]  # powers of 4 from 1KB to 64MB
+    msg_sizes.append(128 * 1024 * 1024)  # add 128MB
+
+    for msg_size in msg_sizes:
+        input_size = msg_size * world_size
+        n_elems = input_size // 4  # float32 = 4 bytes
+
+        input_tensor = torch.randn(n_elems, dtype=torch.float32, device='cuda')
+        output_tensor = torch.empty_like(input_tensor)
+
+        # Fewer warmup/iters for very large sizes
+        n_warmup = 3 if msg_size >= 16 * 1024 * 1024 else 5
+        n_iters = 5 if msg_size >= 64 * 1024 * 1024 else (10 if msg_size >= 4 * 1024 * 1024 else 20)
+
+        # Warmup
+        for _ in range(n_warmup):
+            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
+        torch.cuda.synchronize()
+
+        # Benchmark
+        start = time.perf_counter()
+        for _ in range(n_iters):
+            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
+        torch.cuda.synchronize()
+        elapsed = time.perf_counter() - start
+
+        total_bytes = 2 * input_size * n_iters  # read + write
+        bandwidth_gbps = total_bytes / elapsed / 1e9
+        latency_us = elapsed / n_iters * 1e6
+
+        if rank == 0:
+            if msg_size >= 1024 * 1024:
+                size_str = f"{msg_size // (1024*1024)}MB"
+            elif msg_size >= 1024:
+                size_str = f"{msg_size // 1024}KB"
+            else:
+                size_str = f"{msg_size}B"
+            print(f"  {size_str:>10s} {n_iters:>5d} {elapsed*1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>10.2f}")

     # Cleanup
     dist.barrier()
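The reported bandwidth counts every byte twice, once read from the input buffer and once written to an output buffer. A worked example of the accounting, with all numbers hypothetical:

# Worked example of the bandwidth formula above (hypothetical numbers)
world_size = 8
msg_size = 1 << 20                       # 1MB sent to each peer
input_size = msg_size * world_size       # 8MB input buffer per rank
n_iters = 20
elapsed = 0.004                          # seconds for 20 iterations, hypothetical
total_bytes = 2 * input_size * n_iters   # read + write, matching the benchmark
print(f"{total_bytes / elapsed / 1e9:.2f} GB/s")  # -> 83.89 GB/s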