mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-11 17:00:22 +00:00)

Merge multinode branch
@@ -62,6 +62,7 @@ void register_algorithm(nb::module_& m) {
      .def_prop_ro("buffer_mode", &Algorithm::bufferMode)
      .def_prop_ro("constraint", &Algorithm::constraint)
      .def_prop_ro("type", &Algorithm::type)
      .def("reset", &Algorithm::reset)
      .def(
          "execute",
          [](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
@@ -248,6 +248,10 @@ class MscclppAlltoAllV:
        # Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
        splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
        if splits_key != self._cached_splits_key:
            # Clear cached contexts to free RegisteredMemory for old (possibly freed) tensors.
            # Without this, stale CUDA IPC handles accumulate and eventually SIGSEGV.
            if hasattr(self._algo, 'reset'):
                self._algo.reset()
            # Copy counts/displacements to GPU
            self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
            self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
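The fast path above is a plain tuple-key cache. A minimal, self-contained sketch of the pattern (class and attribute names mirror the diff but are illustrative):

```python
# Sketch of the split-size fast path, assuming an algorithm object with an
# optional reset() method; names mirror the diff but are illustrative.
class SplitCache:
    def __init__(self, algo):
        self._algo = algo
        self._cached_splits_key = None

    def maybe_update(self, send_counts, recv_counts):
        key = (tuple(send_counts), tuple(recv_counts))
        if key == self._cached_splits_key:
            return False  # fast path: splits unchanged, skip copies + exchange
        # Splits changed: drop cached contexts so stale CUDA IPC registrations
        # for old (possibly freed) tensors cannot accumulate.
        if hasattr(self._algo, "reset"):
            self._algo.reset()
        self._cached_splits_key = key
        return True  # caller re-uploads counts/displacements to the GPU
```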
@@ -268,13 +272,16 @@ class MscclppAlltoAllV:
        stream = torch.cuda.current_stream()
        cuda_stream = stream.cuda_stream

        # Use full buffer sizes (not actual data sizes) so the C++ context
        # key (input_ptr, output_ptr, inputSize, outputSize) is always the
        # same when using persistent buffers. This ensures only ONE context
        # is ever created, avoiding bootstrap TCP on every unique size combo.
        # The kernel uses per-peer sendCounts/recvCounts for actual data bounds.
        input_size = input.numel() * elem_size
        output_size = output.numel() * elem_size
        # Use the full underlying storage size (not just the view's active data)
        # for the context key, so that reusing views of the same tensor with
        # different split sizes doesn't create new contexts (which leak
        # RegisteredMemory for stale buffers).
        try:
            input_alloc_size = input.untyped_storage().size()
            output_alloc_size = output.untyped_storage().size()
        except Exception:
            input_alloc_size = input.nelement() * input.element_size()
            output_alloc_size = output.nelement() * output.element_size()

        self._a2av_call_count += 1
        _cid = self._a2av_call_count
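The storage-size trick can be verified directly in PyTorch; a minimal check (CPU tensors suffice, and `untyped_storage().size()` is in bytes):

```python
# Two views of the same tensor report different numel() but share one
# underlying allocation, so a key of (ptr, storage_size) stays identical.
import torch

buf = torch.empty(1024, dtype=torch.float32)
a = buf[:100]
b = buf[:900]
assert a.numel() != b.numel()                                    # view sizes differ
assert a.data_ptr() == b.data_ptr()                              # same base pointer
assert a.untyped_storage().size() == b.untyped_storage().size()  # 4096 bytes each
```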
@@ -297,8 +304,8 @@ class MscclppAlltoAllV:
            self._comm,
            input.data_ptr(),
            output.data_ptr(),
            input_size,
            output_size,
            input_alloc_size,
            output_alloc_size,
            _torch_dtype_to_mscclpp(dtype),
            ReduceOp.NOP,
            cuda_stream,

@@ -12,30 +12,113 @@ Usage:
import torch
import torch.distributed as dist
import os
import sys
import time
import random
import socket
import struct
import pickle
from typing import Callable, List, Tuple

# Must init torch.distributed before importing mscclpp modules
# to set rank/world_size environment variables


def _get_routable_ip() -> str:
    """Get a routable IP address for this host (not 127.0.0.1)."""
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))  # doesn't actually send data
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception:
        return socket.gethostbyname(socket.gethostname())


def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
                             master_addr: str, port: int = 18515) -> bytes:
    """Broadcast UniqueId bytes from rank 0 to all other ranks via TCP."""
    if rank == 0:
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(("", port))
        server.listen(world_size - 1)
        for _ in range(world_size - 1):
            conn, _ = server.accept()
            length = len(unique_id_bytes)
            conn.sendall(struct.pack("!I", length) + unique_id_bytes)
            conn.close()
        server.close()
        return unique_id_bytes
    else:
        # Retry connection to rank 0
        for attempt in range(120):
            try:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.connect((master_addr, port))
                break
            except (ConnectionRefusedError, OSError):
                time.sleep(0.5)
                if attempt == 119:
                    raise RuntimeError(f"Rank {rank}: failed to connect to {master_addr}:{port}")
        raw_len = b""
        while len(raw_len) < 4:
            raw_len += s.recv(4 - len(raw_len))
        length = struct.unpack("!I", raw_len)[0]
        data = b""
        while len(data) < length:
            data += s.recv(length - len(data))
        s.close()
        return data


def main():
    # Get rank/world from MPI environment
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))

    # Set CUDA device
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
    # Set CUDA device — prefer MPI-provided local rank to handle any rank mapping
    local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
                     os.environ.get("MPI_LOCALRANKID",
                     os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
    torch.cuda.set_device(local_rank)

    # Initialize torch.distributed with NCCL (need MASTER_ADDR/PORT)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    # Disable UCX in OpenMPI to avoid version mismatch crashes
    os.environ.setdefault("OMPI_MCA_pml", "ob1")
    os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")

    # Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
    # otherwise gloo avoids IB configuration issues on some clusters.
    # Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
    backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
    # For multi-node: MASTER_ADDR must be set to rank 0's routable IP.
    # Single-node auto-detects; multi-node requires it from the launcher.
    if "MASTER_ADDR" not in os.environ:
        if rank == 0:
            os.environ["MASTER_ADDR"] = _get_routable_ip()
        else:
            # Check if we're single-node (all ranks on same host)
            n_gpus = torch.cuda.device_count()
            if world_size <= n_gpus:
                # Likely single-node – 127.0.0.1 works
                os.environ["MASTER_ADDR"] = "127.0.0.1"
            else:
                raise RuntimeError(
                    f"Rank {rank}: MASTER_ADDR not set for multi-node run "
                    f"(world_size={world_size} > local GPUs={n_gpus}). "
                    f"Set it in your launcher, e.g.:\n"
                    f" mpirun -x MASTER_ADDR=<node0_ip> -x MASTER_PORT=29500 ..."
                )
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
                            device_id=torch.device(f"cuda:{local_rank}"))
    if backend == "nccl":
        # Don't use device_id= eager init — it triggers an immediate NCCL allreduce
        # that fails on some platforms (e.g. GB200 with NCCL 2.28.9).
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    if rank == 0:
        print(f"Testing MscclppAlltoAllV with {world_size} ranks")
@@ -48,33 +131,51 @@ def main():
        UniqueId,
    )
    from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
    import pickle

    # Create mscclpp communicator with TcpBootstrap
    # Use torch.distributed to share the unique ID via pickle
    # Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
    bootstrap = TcpBootstrap(rank, world_size)

    if rank == 0:
        unique_id = bootstrap.create_unique_id()
        # Serialize UniqueId via pickle and broadcast
        pickled = pickle.dumps(unique_id)
        id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
        id_tensor[:len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8)
        # Also send length
        len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device='cuda')
        uid_bytes = pickle.dumps(unique_id)
    else:
        id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
        len_tensor = torch.zeros(1, dtype=torch.int64, device='cuda')

    dist.broadcast(len_tensor, src=0)
    dist.broadcast(id_tensor, src=0)
        uid_bytes = b""

    uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
    if rank != 0:
        pickled_len = int(len_tensor.item())
        pickled = bytes(id_tensor[:pickled_len].cpu().tolist())
        unique_id = pickle.loads(pickled)
        unique_id = pickle.loads(uid_bytes)

    bootstrap.initialize(unique_id)

    # ── Multi-node diagnostics ─────────────────────────────────────────
    import subprocess, platform
    hostname = platform.node()
    n_ranks_per_node = bootstrap.get_n_ranks_per_node()
    is_multi_node = (world_size > n_ranks_per_node)

    # Check IB device availability
    try:
        ib_out = subprocess.check_output(["ibv_devinfo", "-l"], stderr=subprocess.DEVNULL, timeout=5).decode().strip()
        ib_devices = [l.strip() for l in ib_out.splitlines() if l.strip() and "device" not in l.lower()]
    except Exception:
        ib_devices = []

    if rank == 0:
        print(f" Hostname: {hostname}")
        print(f" nRanksPerNode: {n_ranks_per_node}, isMultiNode: {is_multi_node}")
        print(f" IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
        print(f" MSCCLPP_SOCKET_IFNAME: {os.environ.get('MSCCLPP_SOCKET_IFNAME', '<not set>')}")
        if is_multi_node and not ib_devices:
            print(f" WARNING: Multi-node detected but no IB devices! Cross-node will fail.")
    # Also print from rank n_ranks_per_node (first rank on node 1) for comparison
    if is_multi_node and rank == n_ranks_per_node:
        print(f" [Node 1] Hostname: {hostname}, rank={rank}")
        print(f" [Node 1] IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
    # ── End diagnostics ────────────────────────────────────────────────

    comm = Communicator(bootstrap)

    # Create MscclppAlltoAllV with existing communicator
@@ -188,6 +289,21 @@ def main():
                                input_split_sizes=in_splits,
                                output_split_sizes=out_splits)

    # Detect whether torch comparison is possible:
    # - Need NCCL backend (gloo doesn't support all_to_all_single)
    # - Need native NCCL (mscclpp shim doesn't implement all collectives)
    use_torch_baseline = (backend == "nccl")
    if use_torch_baseline:
        try:
            tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
            tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
            dist.all_to_all_single(tiny_out, tiny_in)
            torch.cuda.synchronize()
        except Exception:
            use_torch_baseline = False
            if rank == 0:
                print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")

    def torch_fn(inp, out, in_splits, out_splits):
        dist.all_to_all_single(out, inp,
                               output_split_sizes=out_splits,
@@ -202,22 +318,32 @@ def main():

    def print_header():
        if rank == 0:
            print(f" {'Avg Size':>10s} "
                  f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
                  f"{'torch Lat':>10s} {'torch BW':>9s} "
                  f"{'Speedup':>7s}")
            print(f" {'-'*10} "
                  f"{'-'*12} {'-'*11} "
                  f"{'-'*10} {'-'*9} "
                  f"{'-'*7}")
            if use_torch_baseline:
                print(f" {'Avg Size':>10s} "
                      f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
                      f"{'torch Lat':>10s} {'torch BW':>9s} "
                      f"{'Speedup':>7s}")
                print(f" {'-'*10} "
                      f"{'-'*12} {'-'*11} "
                      f"{'-'*10} {'-'*9} "
                      f"{'-'*7}")
            else:
                print(f" {'Avg Size':>10s} "
                      f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
                print(f" {'-'*10} "
                      f"{'-'*12} {'-'*11}")

    def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
    def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
        if rank == 0:
            speedup = m_bw / t_bw if t_bw > 0 else float('inf')
            print(f" {size_str:>10s} "
                  f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
                  f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
                  f"{speedup:>6.2f}x")
            if t_bw is not None and t_bw > 0:
                speedup = m_bw / t_bw
                print(f" {size_str:>10s} "
                      f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
                      f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
                      f"{speedup:>6.2f}x")
            else:
                print(f" {size_str:>10s} "
                      f"{m_lat:>10.1f}us {m_bw:>9.2f}GB")

    # ── Test 3: Synthetic variable-size sweep ─────────────────────────────
    if rank == 0:
@@ -227,6 +353,13 @@ def main():
    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
    msg_sizes.append(128 * 1024 * 1024)

    # Pre-compute max split sizes across all sweep iterations to allocate
    # fixed-size tensors. Reusing the same tensors keeps the NativeAlgorithm
    # context key stable (same ptrs + sizes) and avoids the context cache
    # leak that causes SIGSEGV when stale RegisteredMemory accumulates.
    max_in_elems = 0
    max_out_elems = 0
    sweep_params = []  # (avg_msg_size, in_splits, out_splits)
    for avg_msg_size in msg_sizes:
        random.seed(12345)
        avg_elems = avg_msg_size // 4
@@ -234,19 +367,41 @@ def main():
        for i in range(world_size):
            row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
            send_matrix.append(row)

        in_splits = send_matrix[rank]
        out_splits = [send_matrix[j][rank] for j in range(world_size)]
        max_in_elems = max(max_in_elems, sum(in_splits))
        max_out_elems = max(max_out_elems, sum(out_splits))
        sweep_params.append((avg_msg_size, in_splits, out_splits))

    inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
    out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
    # Allocate once at max size
    inp = torch.randn(max_in_elems, dtype=torch.float32, device='cuda')
    out = torch.empty(max_out_elems, dtype=torch.float32, device='cuda')

    for avg_msg_size, in_splits, out_splits in sweep_params:
        n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
        n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)

        m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
        t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
        print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
        # Use views into the fixed buffers (same data_ptr → same context key)
        inp_view = inp[:sum(in_splits)]
        out_view = out[:sum(out_splits)]

        m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
        if use_torch_baseline:
            try:
                t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
                print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
            except Exception as e:
                if rank == 0:
                    print(f" [WARN] torch baseline failed: {e}")
                    print(f" [INFO] Disabling torch baseline for remaining sizes")
                use_torch_baseline = False
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass
                print_row(fmt_size(avg_msg_size), m_lat, m_bw)
        else:
            print_row(fmt_size(avg_msg_size), m_lat, m_bw)

    # ── Test 4: Real MoE workloads ───────────────────────────────────────
    # Token counts from real MoE training runs (rank 0's view, 8 GPUs).
@@ -257,19 +412,30 @@ def main():
    # per rank so every rank has the same total send and each NVLink
    # carries a realistically imbalanced load.

    # 10 workloads picked from 3M dispatch records in a real MoE training run,
    # covering the full imbalance spectrum from nearly uniform (1.05×) to
    # extremely skewed (10×). Each has 32768 total tokens → 167.8MB.
    MOE_WORKLOADS = [
        {
            "name": "MoE-A",
            # input_splits=[3976,3916,4497,4838,2888,3839,4355,4459]
            # total_send=167,772,160 total_recv=148,316,160
            "input_tokens": [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459],
        },
        {
            "name": "MoE-B",
            # input_splits=[3009,7161,2719,2766,3428,3010,6290,4385]
            # total_send=167,772,160 total_recv=163,722,240
            "input_tokens": [3009, 7161, 2719, 2766, 3428, 3010, 6290, 4385],
        },
        {"name": "MoE-A",  # imbalance ≈ 1.05× (near-uniform)
         "input_tokens": [4122, 4115, 4000, 4200, 4126, 4046, 4035, 4124]},
        {"name": "MoE-B",  # imbalance ≈ 1.20×
         "input_tokens": [3770, 4236, 3966, 4046, 4524, 4132, 3825, 4269]},
        {"name": "MoE-C",  # imbalance ≈ 1.35×
         "input_tokens": [4142, 4489, 4563, 3380, 3957, 4133, 3958, 4146]},
        {"name": "MoE-D",  # imbalance ≈ 1.50× (median)
         "input_tokens": [4232, 3697, 4619, 4788, 4420, 3192, 3971, 3849]},
        {"name": "MoE-E",  # imbalance ≈ 1.75×
         "input_tokens": [4178, 3209, 4678, 5085, 3108, 3365, 5439, 3706]},
        {"name": "MoE-F",  # imbalance ≈ 2.00×
         "input_tokens": [4582, 3903, 3949, 3727, 4823, 5106, 2553, 4125]},
        {"name": "MoE-G",  # imbalance ≈ 2.50×
         "input_tokens": [4036, 4438, 4804, 6180, 2913, 2472, 4105, 3820]},
        {"name": "MoE-H",  # imbalance ≈ 3.50×
         "input_tokens": [3152, 1722, 4406, 4027, 5365, 6027, 4895, 3174]},
        {"name": "MoE-I",  # imbalance ≈ 5.00×
         "input_tokens": [4384, 4194, 7840, 3079, 3460, 3506, 1568, 4737]},
        {"name": "MoE-J",  # imbalance ≈ 10.00× (extreme skew)
         "input_tokens": [2710, 7661, 3354, 4457, 4609, 766, 3423, 5788]},
    ]
    ELEMS_PER_TOKEN = 2560  # 5120 bytes / 2 bytes-per-bfloat16
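A quick sanity check of the sizing quoted above, as plain arithmetic: 32768 tokens per rank times 2560 bf16 elements per token times 2 bytes is 167,772,160 bytes, i.e. about 167.8 MB:

```python
# Verify the workload sizing stated in the comments: 32768 total tokens,
# 2560 bfloat16 elements per token, 2 bytes per element.
tokens = [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459]  # MoE-A row
total_tokens = sum(tokens)              # 32768
total_bytes = total_tokens * 2560 * 2   # 167,772,160
print(total_tokens, total_bytes, total_bytes / 1e6)  # 32768 167772160 167.77 MB
```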
@@ -304,10 +470,23 @@ def main():
        n_warmup, n_iters = 5, 20

        m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
        t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)

        avg_bytes = total_bytes // world_size
        print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
        if use_torch_baseline:
            try:
                t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
                print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
            except Exception as e:
                if rank == 0:
                    print(f" [WARN] torch baseline failed: {e}")
                    print(f" [INFO] Disabling torch baseline for remaining workloads")
                use_torch_baseline = False
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass
                print_row(fmt_size(avg_bytes), m_lat, m_bw)
        else:
            print_row(fmt_size(avg_bytes), m_lat, m_bw)
    else:
        if rank == 0:
            print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")

@@ -157,12 +157,53 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
        }
      }
    }
  } else if (transports.has(Transport::CudaIpc)) {
  } else if (transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) {
    auto entry = getTransportInfo(Transport::CudaIpc);
    auto gpuIpcMem = GpuIpcMem::create(entry.gpuIpcMemHandle);
    // Create a memory map for the remote GPU memory. The memory map will keep the GpuIpcMem instance alive.
    this->remoteMemMap = gpuIpcMem->map();
    this->data = this->remoteMemMap.get();
  } else if (transports.has(Transport::CudaIpc) && getHostHash() != this->hostHash) {
    // Cross-node CudaIpc: try available handle types in order of preference.
    // On GB200 NVSwitch, both Fabric and RuntimeIpc handles work cross-node.
    // On H100 (no NVSwitch across nodes), none of these will work.
    auto entry = getTransportInfo(Transport::CudaIpc);
    bool mapped = false;

    // 1) Try Fabric handle first (works on any NVSwitch-connected system)
    if (!mapped && (entry.gpuIpcMemHandle.typeFlags & GpuIpcMemHandle::Type::Fabric)) {
      GpuIpcMemHandle fabricOnlyHandle = entry.gpuIpcMemHandle;
      fabricOnlyHandle.typeFlags = GpuIpcMemHandle::Type::Fabric;
      try {
        auto gpuIpcMem = GpuIpcMem::create(fabricOnlyHandle);
        this->remoteMemMap = gpuIpcMem->map();
        this->data = this->remoteMemMap.get();
        mapped = true;
        INFO(GPU, "Mapped cross-node CudaIpc memory via Fabric handle at pointer ", this->data);
      } catch (const std::exception& e) {
        INFO(GPU, "Fabric handle mapping failed (will try RuntimeIpc): ", e.what());
      }
    }

    // 2) Try RuntimeIpc handle (cudaIpcOpenMemHandle — works on GB200 NVSwitch cross-node)
    if (!mapped && (entry.gpuIpcMemHandle.typeFlags & GpuIpcMemHandle::Type::RuntimeIpc)) {
      GpuIpcMemHandle runtimeOnlyHandle = entry.gpuIpcMemHandle;
      runtimeOnlyHandle.typeFlags = GpuIpcMemHandle::Type::RuntimeIpc;
      try {
        auto gpuIpcMem = GpuIpcMem::create(runtimeOnlyHandle);
        this->remoteMemMap = gpuIpcMem->map();
        this->data = this->remoteMemMap.get();
        mapped = true;
        INFO(GPU, "Mapped cross-node CudaIpc memory via RuntimeIpc handle at pointer ", this->data);
      } catch (const std::exception& e) {
        INFO(GPU, "RuntimeIpc handle mapping failed for cross-node peer: ", e.what());
      }
    }

    if (!mapped) {
      WARN(GPU, "Skipping CudaIpc map for cross-node peer (all handle types failed, local hostHash=",
           getHostHash(), ", remote hostHash=", this->hostHash, ")");
    }
  }
  if (this->data != nullptr) {
    INFO(GPU, "Opened CUDA IPC handle at pointer ", this->data);

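The cross-node mapping is an ordered-fallback loop over handle types. The same shape in Python (the flags and mapper callables are illustrative stand-ins, not mscclpp API):

```python
# Ordered fallback over exported handle types; first successful mapping wins.
FABRIC, RUNTIME_IPC = 0x1, 0x2

def map_remote(type_flags, mappers):
    """Try each (name, flag, mapper) in preference order."""
    for name, flag, mapper in mappers:
        if not (type_flags & flag):
            continue                 # peer did not export this handle type
        try:
            return name, mapper()
        except RuntimeError:
            continue                 # log and fall through to the next type
    return None, None                # caller warns and skips the CudaIpc map

# Fabric export present but mapping fails -> falls back to RuntimeIpc:
def fail():
    raise RuntimeError("fabric import failed")

name, _ = map_remote(FABRIC | RUNTIME_IPC,
                     [("Fabric", FABRIC, fail),
                      ("RuntimeIpc", RUNTIME_IPC, lambda: object())])
assert name == "RuntimeIpc"
```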
@@ -8,9 +8,13 @@
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/utils.hpp>

#include <algorithm>
#include "debug.h"

namespace mscclpp {
namespace collective {
@@ -21,17 +25,38 @@ namespace collective {
#define ALLTOALLV_WARP_SIZE 32
#endif

using MultiNodeMode = AlltoallvFullmesh::MultiNodeMode;

// Context to hold all necessary state for alltoallv execution
struct AllToAllVContext {
  int rank;
  int worldSize;
  int nRanksPerNode;

  // MemoryChannel (CudaIpc) — used for intra-node (always) and cross-node (NVSwitch mode)
  std::vector<RegisteredMemory> registeredMemories;
  std::vector<MemoryChannel> memoryChannels;
  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
  std::shared_ptr<DeviceHandle<MemoryChannel>> memoryChannelDeviceHandles;
  std::shared_ptr<DeviceSyncer> deviceSyncer;  // GPU-allocated, for multi-block grid sync

  // PortChannel (IB) — used for cross-node peers in IB mode only
  std::shared_ptr<ProxyService> proxyService;
  std::vector<PortChannel> portChannels;
  std::shared_ptr<PortChannelDeviceHandle> portChannelDeviceHandles;

  // Peer locality map (IB mode only)
  std::shared_ptr<int> d_peerIsLocal;           // GPU array [nPeers]
  std::shared_ptr<int> d_peerToPortChannelIdx;  // GPU array [nPeers]

  // Staging buffers (NVSwitch mode only): allocated via GpuBuffer (cuMemCreate → Fabric handles)
  bool useStaging;
  std::shared_ptr<GpuBuffer<char>> inputStaging;
  std::shared_ptr<GpuBuffer<char>> outputStaging;

  // Which kernel dispatch path to use
  AlltoallvFullmesh::MultiNodeMode mode;

  std::shared_ptr<DeviceSyncer> deviceSyncer;
};

AlltoallvFullmesh::~AlltoallvFullmesh() = default;
@@ -68,12 +93,46 @@ std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {

void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
  worldSize_ = comm->bootstrap()->getNranks();
  this->conns_ = setupConnections(comm);
  int rank = comm->bootstrap()->getRank();
  int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
  int localGpuIdx = rank % nRanksPerNode;
  bool isMultiNode = (worldSize_ > nRanksPerNode);
  bool nvlsSupported = isNvlsSupported();
  int ibDevCount = getIBDeviceCount();

  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initialize: worldSize=%d, nRanksPerNode=%d, "
       "isMultiNode=%d, isNvlsSupported=%d, ibDevCount=%d, localGpuIdx=%d",
       rank, worldSize_, nRanksPerNode, isMultiNode, nvlsSupported, ibDevCount, localGpuIdx);

  if (!isMultiNode) {
    multiNodeMode_ = MultiNodeMode::SingleNode;
    this->conns_ = setupConnections(comm);
  } else if (nvlsSupported) {
    multiNodeMode_ = MultiNodeMode::NVSwitch;
    this->conns_ = setupConnections(comm);
  } else {
    if (ibDevCount <= 0) {
      throw Error("Multi-node alltoallv requires IB transport but no IB devices found. "
                  "Ensure IB drivers are loaded and devices are available.",
                  ErrorCode::InvalidUsage);
    }
    multiNodeMode_ = MultiNodeMode::IB;
    this->conns_ = setupHybridConnections(comm, localGpuIdx);
  }

  const char* modeStr = (multiNodeMode_ == MultiNodeMode::SingleNode) ? "SingleNode" :
                        (multiNodeMode_ == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] mode=%s, connections=%zu",
       rank, modeStr, this->conns_.size());
  for (size_t i = 0; i < this->conns_.size(); ++i) {
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] conn[%zu] transport=%d",
         rank, i, (int)this->conns_[i].transport());
  }
}
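The mode selection above reduces to a small decision table. A sketch of the same logic (the real inputs come from the bootstrap, isNvlsSupported(), and getIBDeviceCount()):

```python
# Decision table for the multi-node transport mode chosen in initialize().
def pick_mode(world_size, n_ranks_per_node, nvls_supported, ib_dev_count):
    if world_size <= n_ranks_per_node:
        return "SingleNode"   # CudaIpc MemoryChannels only
    if nvls_supported:
        return "NVSwitch"     # cross-node CudaIpc via Fabric handles
    if ib_dev_count <= 0:
        raise RuntimeError("multi-node alltoallv requires IB devices")
    return "IB"               # hybrid: CudaIpc intra-node + IB PortChannels

assert pick_mode(8, 8, False, 0) == "SingleNode"
assert pick_mode(16, 8, True, 0) == "NVSwitch"
assert pick_mode(16, 8, False, 4) == "IB"
```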
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
    const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
    size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
    [[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
    [[maybe_unused]] int nBlocks, int nThreadsPerBlock,
    const std::unordered_map<std::string, uintptr_t>& extras) {
@@ -103,44 +162,71 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
  // Use maximum threads (1024) for best bandwidth utilization
  const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;

  // Peer-parallel algorithm: blocks assigned round-robin to peers so ALL
  // NVLink connections are active simultaneously. Critical for 4+ GPU systems.
  //
  // Small messages (<1MB avg): nPeers blocks (1 per peer, no barrier)
  // Large messages (>=1MB avg): nPeers * blocksPerPeer (barrier-based)
  constexpr size_t SIZE_THRESHOLD = 1 << 20;  // 1MB
  size_t avgMsgSize = inputSize / worldSize;
  int nPeers = worldSize - 1;
  if (nPeers < 1) nPeers = 1;

  if (avgMsgSize < SIZE_THRESHOLD) {
    // Small messages: 1 block per peer, parallel signal/wait, no barrier
  // Determine send/recv buffer pointers.
  // NVSwitch mode: copy PyTorch data to/from GpuBuffer staging buffers.
  const void* sendBuff = input;
  void* recvBuff = output;

  if (algoCtx->useStaging) {
    sendBuff = algoCtx->inputStaging->data();
    recvBuff = algoCtx->outputStaging->data();
    MSCCLPP_CUDATHROW(cudaMemcpyAsync(
        const_cast<void*>(sendBuff), input,
        inputSize, cudaMemcpyDeviceToDevice, stream));
  }

  if (algoCtx->mode == MultiNodeMode::IB) {
    // ── IB mode: PortChannel kernel for ALL peers ──────────────────────
    // PortChannel handles both CudaIpc (intra) and IB (inter) connections
    // via the ProxyService proxy thread.
    int numBlocks = nPeers;
    alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
        algoCtx->memoryChannelDeviceHandles.get(),
        algoCtx->deviceSyncer.get(),
    alltoallvPortChannelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
        algoCtx->portChannelDeviceHandles.get(),
        rank, worldSize,
        input, output,
        sendBuff, recvBuff,
        d_sendCounts, d_sendDispls,
        d_recvCounts, d_recvDispls,
        d_remoteRecvDispls);
  } else {
    // Large messages: multiple blocks per peer for maximum put bandwidth.
    // Cap total blocks to avoid excessive barrier overhead.
    int blocksPerPeer = (nBlocks > 0 && nBlocks <= 128)
        ? ((nBlocks + nPeers - 1) / nPeers)  // user-specified total → per-peer
        : ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
    int numBlocks = nPeers * blocksPerPeer;
    if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers;  // keep multiple of nPeers
    if (numBlocks < nPeers) numBlocks = nPeers;
    alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
        algoCtx->memoryChannelDeviceHandles.get(),
        algoCtx->deviceSyncer.get(),
        rank, worldSize,
        input, output,
        d_sendCounts, d_sendDispls,
        d_recvCounts, d_recvDispls,
        d_remoteRecvDispls);
    // ── SingleNode / NVSwitch mode: MemoryChannel kernel ───────────────
    constexpr size_t SIZE_THRESHOLD = 1 << 20;  // 1MB
    size_t avgMsgSize = inputSize / worldSize;

    if (avgMsgSize < SIZE_THRESHOLD) {
      int numBlocks = nPeers;
      alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
          algoCtx->memoryChannelDeviceHandles.get(),
          algoCtx->deviceSyncer.get(),
          rank, worldSize,
          sendBuff, recvBuff,
          d_sendCounts, d_sendDispls,
          d_recvCounts, d_recvDispls,
          d_remoteRecvDispls);
    } else {
      int blocksPerPeer = (nBlocks > 0 && nBlocks <= 128)
          ? ((nBlocks + nPeers - 1) / nPeers)
          : ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
      int numBlocks = nPeers * blocksPerPeer;
      if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers;
      if (numBlocks < nPeers) numBlocks = nPeers;
      alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
          algoCtx->memoryChannelDeviceHandles.get(),
          algoCtx->deviceSyncer.get(),
          rank, worldSize,
          sendBuff, recvBuff,
          d_sendCounts, d_sendDispls,
          d_recvCounts, d_recvDispls,
          d_remoteRecvDispls);
    }
  }

  if (algoCtx->useStaging) {
    MSCCLPP_CUDATHROW(cudaMemcpyAsync(
        output, recvBuff,
        outputSize, cudaMemcpyDeviceToDevice, stream));
  }

  if (cudaGetLastError() == cudaSuccess) {
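The large-message launch geometry is pure arithmetic; a sketch of the same policy (names mirror the launch code, with ALLTOALLV_DEFAULT_BLOCKS_PER_PEER = 16 and a 128-block cap):

```python
# Block-count policy for the large-message path, as plain arithmetic.
def num_blocks(n_peers, n_blocks_hint=0, default_per_peer=16, cap=128):
    per_peer = ((n_blocks_hint + n_peers - 1) // n_peers   # user total -> per-peer
                if 0 < n_blocks_hint <= cap else default_per_peer)
    blocks = n_peers * per_peer
    if blocks > cap:
        blocks = (cap // n_peers) * n_peers  # keep a multiple of n_peers
    return max(blocks, n_peers)              # at least one block per peer

assert num_blocks(7) == 112    # 7 peers * 16 blocks each
assert num_blocks(9) == 126    # 144 exceeds the cap -> (128 // 9) * 9
```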
@@ -157,37 +243,104 @@ std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
  ctx->rank = comm->bootstrap()->getRank();
  ctx->worldSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
  ctx->mode = this->multiNodeMode_;
  ctx->useStaging = (ctx->mode == MultiNodeMode::NVSwitch);

  // Register memories for input and output buffers
  RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
  RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, Transport::CudaIpc);
  int rank = ctx->rank;
  int localGpuIdx = rank % ctx->nRanksPerNode;
  const char* modeStr = (ctx->mode == MultiNodeMode::SingleNode) ? "SingleNode" :
                        (ctx->mode == MultiNodeMode::NVSwitch) ? "NVSwitch" : "IB";
  INFO(MSCCLPP_COLL, "[alltoallv][rank %d] initContext: mode=%s, useStaging=%d, "
       "input=%p (%zu B), output=%p (%zu B), localGpuIdx=%d",
       rank, modeStr, ctx->useStaging, input, inputSize, output, outputSize, localGpuIdx);

  // Exchange output buffer registration with all peers (we write to peer's output buffer)
  std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputBufRegMem);
  if (ctx->mode == MultiNodeMode::NVSwitch) {
    // ── NVSwitch (GB200): staging GpuBuffers + CudaIpc MemoryChannel for all peers
    ctx->inputStaging = std::make_shared<GpuBuffer<char>>(inputSize);
    ctx->outputStaging = std::make_shared<GpuBuffer<char>>(outputSize);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch staging: input=%p (%zu B), output=%p (%zu B)",
         rank, ctx->inputStaging->data(), inputSize, ctx->outputStaging->data(), outputSize);

    // Setup memory semaphores for synchronization (1 channel per peer)
    constexpr int nChannelsPerConnection = 1;
    ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
    TransportFlags allTransports = Transport::CudaIpc;
    RegisteredMemory inputBufRegMem = comm->registerMemory(
        ctx->inputStaging->data(), ctx->inputStaging->bytes(), allTransports);
    RegisteredMemory outputBufRegMem = comm->registerMemory(
        ctx->outputStaging->data(), ctx->outputStaging->bytes(), allTransports);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: registered input=%p, output=%p",
         rank, inputBufRegMem.data(), outputBufRegMem.data());

    // Setup memory channels: we read from our input buffer, write to peer's output buffer
    ctx->memoryChannels = setupMemoryChannels(
        this->conns_,
        ctx->memorySemaphores,
        remoteOutputMemories,  // remote output buffers (where we write)
        inputBufRegMem,        // local input buffer (where we read from)
        nChannelsPerConnection);
    std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
    for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
      INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: remoteOutput[%zu] data=%p, size=%zu",
           rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
      if (remoteOutputMemories[i].data() == nullptr) {
        INFO(MSCCLPP_COLL, "[alltoallv][rank %d] ERROR: remoteOutput[%zu] has NULL data pointer! "
             "Cross-node CudaIpc mapping failed.", rank, i);
      }
    }

    // Setup device handles
    ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
    constexpr int nChannelsPerConnection = 1;
    ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu semaphores created",
         rank, ctx->memorySemaphores.size());
    ctx->memoryChannels = setupMemoryChannels(
        this->conns_, ctx->memorySemaphores, remoteOutputMemories, inputBufRegMem, nChannelsPerConnection);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] NVSwitch: %zu memoryChannels created",
         rank, ctx->memoryChannels.size());
    ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

    // Allocate GPU DeviceSyncer for multi-block grid-wide barrier (zero-initialized)
    ctx->registeredMemories = std::move(remoteOutputMemories);
    ctx->registeredMemories.push_back(inputBufRegMem);
    ctx->registeredMemories.push_back(outputBufRegMem);

  } else if (ctx->mode == MultiNodeMode::IB) {
    // ── IB: PortChannel for ALL peers (CudaIpc intra + IB inter connections)
    TransportFlags allTransports = Transport::CudaIpc | getIBTransportForGpu(localGpuIdx);
    RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, allTransports);
    RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);

    std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: input=%p (%zu B), output=%p (%zu B), remotes=%zu",
         rank, input, inputSize, output, outputSize, remoteOutputMemories.size());
    for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
      INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: remoteOutput[%zu] data=%p, size=%zu",
           rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
    }

    ctx->proxyService = std::make_shared<ProxyService>();
    ctx->portChannels = setupAllPortChannels(
        ctx->proxyService, *comm, this->conns_, remoteOutputMemories, inputBufRegMem);
    ctx->portChannelDeviceHandles = setupPortChannelDeviceHandles(ctx->portChannels);
    ctx->proxyService->startProxy(true);
    INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: %zu portChannels created, proxy started",
         rank, ctx->portChannels.size());

    ctx->registeredMemories = std::move(remoteOutputMemories);
    ctx->registeredMemories.push_back(inputBufRegMem);
    ctx->registeredMemories.push_back(outputBufRegMem);

  } else {
    // ── SingleNode: CudaIpc MemoryChannel (direct PyTorch buffers)
    TransportFlags allTransports = Transport::CudaIpc;
    RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, allTransports);
    RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);

    std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);

    constexpr int nChannelsPerConnection = 1;
    ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
    ctx->memoryChannels = setupMemoryChannels(
        this->conns_, ctx->memorySemaphores, remoteOutputMemories, inputBufRegMem, nChannelsPerConnection);
    ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

    ctx->registeredMemories = std::move(remoteOutputMemories);
    ctx->registeredMemories.push_back(inputBufRegMem);
    ctx->registeredMemories.push_back(outputBufRegMem);
  }

  // Allocate GPU DeviceSyncer for multi-block grid-wide barrier
  ctx->deviceSyncer = mscclpp::detail::gpuCallocShared<DeviceSyncer>();

  // Keep registered memory references to prevent deallocation
  ctx->registeredMemories = std::move(remoteOutputMemories);
  ctx->registeredMemories.push_back(inputBufRegMem);
  ctx->registeredMemories.push_back(outputBufRegMem);

  return ctx;
}

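The staging flow that brackets the kernel can be summarized host-side. A sketch under stated assumptions: bytearrays stand in for the GpuBuffer<char> allocations, and `exchange` is a hypothetical placeholder for the peer-write step:

```python
# Host-side sketch of NVSwitch staging: the collective always runs on
# persistently allocated buffers, with copies in and out, so the registered
# (ptr, size) pair never changes across calls.
class StagedA2AV:
    def __init__(self, max_in: int, max_out: int):
        self.send = bytearray(max_in)       # allocated once, registered once
        self.recv = bytearray(max_out)

    def run(self, inp: bytes, exchange) -> bytes:
        self.send[:len(inp)] = inp          # copy in  (cudaMemcpyAsync D2D)
        exchange(self.send, self.recv)      # peers write into self.recv
        return bytes(self.recv[:len(inp)])  # copy out (cudaMemcpyAsync D2D)

def loopback(send, recv):                   # trivial single-rank "exchange"
    recv[:len(send)] = send

staged = StagedA2AV(16, 16)
assert staged.run(b"hello", loopback) == b"hello"
```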
@@ -7,7 +7,9 @@
#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/switch_channel.hpp>
#include <mscclpp/utils.hpp>

namespace mscclpp {
namespace collective {
@@ -31,11 +33,17 @@ std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
    const std::vector<mscclpp::RegisteredMemory>& remoteMemories, mscclpp::RegisteredMemory localMemory,
    int nChannelsPerConnection) {
  std::vector<mscclpp::MemoryChannel> channels;
  size_t nConnections = connections.size();
  // Count number of CudaIpc connections for proper dense indexing into memorySemaphores
  size_t nCudaIpcConns = 0;
  for (size_t cid = 0; cid < connections.size(); ++cid) {
    if (connections[cid].transport() == mscclpp::Transport::CudaIpc) nCudaIpcConns++;
  }
  for (int idx = 0; idx < nChannelsPerConnection; ++idx) {
    for (size_t cid = 0; cid < nConnections; ++cid) {
    size_t semIdx = 0;
    for (size_t cid = 0; cid < connections.size(); ++cid) {
      if (connections[cid].transport() == mscclpp::Transport::CudaIpc) {
        channels.emplace_back(memorySemaphores[idx * nConnections + cid], remoteMemories[cid], localMemory, nullptr);
        channels.emplace_back(memorySemaphores[idx * nCudaIpcConns + semIdx], remoteMemories[cid], localMemory, nullptr);
        semIdx++;
      }
    }
  }
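The dense-indexing fix is easiest to see on a toy connection list: semaphores exist only for CudaIpc connections, so indexing them by raw connection id over-runs the array once an IB connection precedes a CudaIpc one. A minimal model:

```python
# Mixed transports: 4 connections but only 2 CudaIpc semaphores.
transports = ["CudaIpc", "IB", "CudaIpc", "IB"]
semaphores = ["sem0", "sem1"]        # one per CudaIpc connection, dense

dense = 0
pairs = []
for cid, t in enumerate(transports):
    if t != "CudaIpc":
        continue
    # Buggy form: semaphores[cid] would ask for index 2 on the second
    # CudaIpc connection, past the end of the semaphore array.
    pairs.append((cid, semaphores[dense]))   # fixed: dense index
    dense += 1
assert pairs == [(0, "sem0"), (2, "sem1")]
```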
@@ -54,6 +62,100 @@ std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Commu
  return connections;
}

// IB device array — GPU index maps to its dedicated IB device
static const mscclpp::Transport IBs[] = {
    mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, mscclpp::Transport::IB3,
    mscclpp::Transport::IB4, mscclpp::Transport::IB5, mscclpp::Transport::IB6, mscclpp::Transport::IB7,
};

mscclpp::Transport getIBTransportForGpu(int localGpuIdx) {
  int ibCount = mscclpp::getIBDeviceCount();
  if (ibCount <= 0) {
    throw std::runtime_error("No IB devices available for inter-node communication");
  }
  int idx = localGpuIdx % ibCount;
  return IBs[idx];
}

std::vector<mscclpp::Connection> setupHybridConnections(std::shared_ptr<mscclpp::Communicator> comm,
                                                        int localGpuIdx) {
  int rank = comm->bootstrap()->getRank();
  int worldSize = comm->bootstrap()->getNranks();
  int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
  int thisNode = rank / nRanksPerNode;

  bool hasIB = mscclpp::getIBDeviceCount() > 0;
  mscclpp::Transport ibTransport = hasIB ? getIBTransportForGpu(localGpuIdx) : mscclpp::Transport::CudaIpc;

  std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
  for (int r = 0; r < worldSize; r++) {
    if (r == rank) continue;
    mscclpp::Transport transport;
    if (r / nRanksPerNode == thisNode) {
      transport = mscclpp::Transport::CudaIpc;
    } else {
      transport = ibTransport;
    }
    connectionFutures.push_back(comm->connect(transport, r));
  }

  std::vector<mscclpp::Connection> connections;
  std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                 [](const auto& future) { return future.get(); });
  return connections;
}

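The per-peer transport choice in setupHybridConnections is a one-line locality test (same node if `r / nRanksPerNode` matches). The same rule as a table:

```python
# Per-peer transport selection: same node -> CudaIpc, else this GPU's IB device.
def peer_transports(rank, world_size, ranks_per_node, local_gpu_idx, ib_count):
    ib = f"IB{local_gpu_idx % ib_count}"   # dedicated IB device for this GPU
    me = rank // ranks_per_node
    return {r: ("CudaIpc" if r // ranks_per_node == me else ib)
            for r in range(world_size) if r != rank}

# rank 1 of 8 ranks, 4 per node, local GPU 1, 4 IB devices:
assert peer_transports(1, 8, 4, 1, 4) == {
    0: "CudaIpc", 2: "CudaIpc", 3: "CudaIpc",
    4: "IB1", 5: "IB1", 6: "IB1", 7: "IB1"}
```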
std::vector<mscclpp::PortChannel> setupPortChannels(
    std::shared_ptr<mscclpp::ProxyService> proxyService,
    mscclpp::Communicator& comm,
    const std::vector<mscclpp::Connection>& connections,
    const std::vector<mscclpp::RegisteredMemory>& remoteMemories,
    mscclpp::RegisteredMemory localMemory) {
  std::vector<mscclpp::PortChannel> channels;
  mscclpp::MemoryId srcMemId = proxyService->addMemory(localMemory);
  for (size_t cid = 0; cid < connections.size(); ++cid) {
    if (connections[cid].transport() != mscclpp::Transport::CudaIpc) {
      // IB connection → PortChannel
      mscclpp::SemaphoreId semId = proxyService->buildAndAddSemaphore(comm, connections[cid]);
      mscclpp::MemoryId dstMemId = proxyService->addMemory(remoteMemories[cid]);
      channels.emplace_back(proxyService->portChannel(semId, dstMemId, srcMemId));
    }
  }
  return channels;
}

std::vector<mscclpp::PortChannel> setupAllPortChannels(
    std::shared_ptr<mscclpp::ProxyService> proxyService,
    mscclpp::Communicator& comm,
    const std::vector<mscclpp::Connection>& connections,
    const std::vector<mscclpp::RegisteredMemory>& remoteMemories,
    mscclpp::RegisteredMemory localMemory) {
  std::vector<mscclpp::PortChannel> channels;
  mscclpp::MemoryId srcMemId = proxyService->addMemory(localMemory);
  for (size_t cid = 0; cid < connections.size(); ++cid) {
    // Create PortChannel for EVERY connection (CudaIpc and IB alike).
    // The ProxyService proxy thread handles both connection types:
    //   - CudaIpc: cudaMemcpyD2D via IPC-mapped pointer
    //   - IB: RDMA write via ibv_post_send
    mscclpp::SemaphoreId semId = proxyService->buildAndAddSemaphore(comm, connections[cid]);
    mscclpp::MemoryId dstMemId = proxyService->addMemory(remoteMemories[cid]);
    channels.emplace_back(proxyService->portChannel(semId, dstMemId, srcMemId));
  }
  return channels;
}

std::shared_ptr<mscclpp::PortChannelDeviceHandle> setupPortChannelDeviceHandles(
    const std::vector<mscclpp::PortChannel>& portChannels) {
  if (portChannels.empty()) return nullptr;
  std::vector<mscclpp::PortChannelDeviceHandle> handles;
  std::transform(portChannels.begin(), portChannels.end(), std::back_inserter(handles),
                 [](const mscclpp::PortChannel& ch) { return ch.deviceHandle(); });
  auto ptr = mscclpp::detail::gpuCallocShared<mscclpp::PortChannelDeviceHandle>(handles.size());
  mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
      ptr.get(), handles.data(), handles.size(), cudaMemcpyHostToDevice);
  return ptr;
}

std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
    std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
    int nChannelsPerConnection) {

@@ -5,7 +5,9 @@

#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/semaphore.hpp>

namespace mscclpp {
@@ -33,6 +35,9 @@ class AlltoallvFullmesh : public AlgorithmBuilder {

  std::shared_ptr<Algorithm> build() override;

  // Multi-node transport mode, decided at initialize() time
  enum class MultiNodeMode { SingleNode, NVSwitch, IB };

 private:
  void initialize(std::shared_ptr<Communicator> comm);

@@ -50,6 +55,7 @@ class AlltoallvFullmesh : public AlgorithmBuilder {

  std::vector<Connection> conns_;
  int worldSize_;
  MultiNodeMode multiNodeMode_ = MultiNodeMode::SingleNode;
};

} // namespace collective

@@ -4,6 +4,7 @@
#pragma once

#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/copy_device.hpp>

@@ -29,6 +30,117 @@ constexpr int ALLTOALLV_DEFAULT_NBLOCKS = 24;
// Controls how many thread blocks cooperate on each peer's data transfer.
constexpr int ALLTOALLV_DEFAULT_BLOCKS_PER_PEER = 16;

/**
 * Hybrid AllToAllV kernel for multi-node: MemoryChannel (intra-node) + PortChannel (inter-node).
 *
 * Each block handles one peer (1 block per peer). For intra-node peers, all threads
 * cooperate on a MemoryChannel put (multi-threaded NVLink copy). For inter-node peers,
 * thread 0 pushes a PortChannel put descriptor to the CPU proxy FIFO (single-threaded),
 * which triggers an RDMA transfer.
 *
 * Key design points:
 * - MemoryChannel uses peerIdx-based dense indexing (only intra-node peers have MemoryChannels)
 *   but we need the SAME peerIdx ordering as the connection array.
 *   In practice, memoryChannels[] are created only for CudaIpc connections and are dense.
 *   We use a separate peerToMemChIdx mapping from peerIsLocal.
 * - PortChannel uses separate dense indexing via peerToPortChannelIdx.
 * - Signal/wait is done per-peer by thread 0 of each block.
 *
 * Launch config: <<<nPeers, 1024>>>
 */
__global__ void __launch_bounds__(1024)
    alltoallvHybridKernel(DeviceHandle<MemoryChannel>* memoryChannels,
                          PortChannelDeviceHandle* portChannels,
                          const int* peerIsLocal,
                          const int* peerToPortChannelIdx,
                          DeviceSyncer* syncer,
                          int rank,
                          int worldSize,
                          const void* sendBuff,
                          void* recvBuff,
                          const size_t* sendCounts,
                          const size_t* sendDispls,
                          const size_t* recvCounts,
                          const size_t* recvDispls,
                          const size_t* remoteRecvDispls) {
  const int nPeers = worldSize - 1;

  // Handle trivial case (single rank)
  if (nPeers == 0) {
    const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
    const int nThreads = blockDim.x * gridDim.x;
    if (sendCounts[rank] > 0) {
      mscclpp::copy((char*)recvBuff + recvDispls[rank],
                    (void*)((const char*)sendBuff + sendDispls[rank]),
                    sendCounts[rank], gtid, nThreads);
    }
    return;
  }

  // Phase 1: Local copy — all blocks cooperate using global thread IDs
  const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
  const int nThreads = blockDim.x * gridDim.x;
  if (sendCounts[rank] > 0) {
    mscclpp::copy((char*)recvBuff + recvDispls[rank],
                  (void*)((const char*)sendBuff + sendDispls[rank]),
                  sendCounts[rank], gtid, nThreads);
  }

  // Phase 2: Per-peer data transfer.
  // Each block handles one peer: blockIdx.x == peerIdx
  const int peerIdx = blockIdx.x;
  if (peerIdx >= nPeers) return;

  const int peer = peerIdx < rank ? peerIdx : peerIdx + 1;

  if (peerIsLocal[peerIdx]) {
    // Intra-node: MemoryChannel — all threads cooperate on multi-threaded put
    // MemoryChannels are densely indexed for CudaIpc connections only.
    // We need to compute the MemoryChannel index from peerIdx.
    // Count how many local peers are before this peerIdx.
    int memChIdx = 0;
    for (int i = 0; i < peerIdx; i++) {
      if (peerIsLocal[i]) memChIdx++;
    }

    if (sendCounts[peer] > 0) {
      memoryChannels[memChIdx].put(
          remoteRecvDispls[peer],  // dst offset in peer's buffer
          sendDispls[peer],        // src offset in our buffer
          sendCounts[peer],        // size
          threadIdx.x,             // thread id within block
          blockDim.x               // total threads for this peer
      );
    }
    __syncthreads();

    // Signal and wait (thread 0 only)
    if (threadIdx.x == 0) {
      memoryChannels[memChIdx].signal();
      if (recvCounts[peer] > 0) {
        memoryChannels[memChIdx].wait();
      }
    }
  } else {
    // Inter-node: PortChannel — single-threaded FIFO push
    int portChIdx = peerToPortChannelIdx[peerIdx];

    if (threadIdx.x == 0 && sendCounts[peer] > 0) {
      portChannels[portChIdx].putWithSignalAndFlush(
          remoteRecvDispls[peer],  // dst offset
          sendDispls[peer],        // src offset
          sendCounts[peer]         // size
      );
    }
    __syncthreads();

    // Wait for incoming data from remote peer
    if (threadIdx.x == 0 && recvCounts[peer] > 0) {
      portChannels[portChIdx].wait();
    }
  }
}
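Both kernels map block b to peer `b if b < rank else b + 1`, which assigns each of the worldSize - 1 blocks a distinct peer while skipping this rank; a two-line check:

```python
# Skip-self peer mapping used by the alltoallv kernels: blockIdx.x -> peer.
def block_to_peer(rank, world_size):
    return [b if b < rank else b + 1 for b in range(world_size - 1)]

assert block_to_peer(2, 4) == [0, 1, 3]   # rank 2 of 4: blocks cover peers 0,1,3
```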
/**
 * Peer-parallel AllToAllV kernel for maximum throughput with multiple GPUs.
 *
@@ -400,6 +512,79 @@ __global__ void __launch_bounds__(1024)
  }
}

/**
 * PortChannel-only AllToAllV kernel for multi-node.
 *
 * Uses PortChannel (proxy-based) for ALL peers — both intra-node and inter-node.
 * This follows the proven pattern from allgather_test_cpp.cu which works reliably
 * on GB200 multi-node NVSwitch systems.
 *
 * For intra-node CudaIpc connections, the proxy performs cudaMemcpyD2D.
 * For inter-node IB connections, the proxy performs RDMA writes.
 *
 * Each block handles one peer. Thread 0 pushes a put descriptor to the FIFO
 * (single-threaded), which triggers the proxy to perform the data transfer.
 *
 * Launch config: <<<nPeers, 1024>>>
 */
__global__ void __launch_bounds__(1024)
    alltoallvPortChannelKernel(PortChannelDeviceHandle* portChannels,
                               int rank,
                               int worldSize,
                               const void* sendBuff,
                               void* recvBuff,
                               const size_t* sendCounts,
                               const size_t* sendDispls,
                               const size_t* recvCounts,
                               const size_t* recvDispls,
                               const size_t* remoteRecvDispls) {
  const int nPeers = worldSize - 1;

  // Handle trivial case (single rank)
  if (nPeers == 0) {
    const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
    const int nThreads = blockDim.x * gridDim.x;
    if (sendCounts[rank] > 0) {
      mscclpp::copy((char*)recvBuff + recvDispls[rank],
                    (void*)((const char*)sendBuff + sendDispls[rank]),
                    sendCounts[rank], gtid, nThreads);
    }
    return;
  }

  // Phase 1: Local copy — all blocks cooperate using global thread IDs
  const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
  const int nThreads = blockDim.x * gridDim.x;
  if (sendCounts[rank] > 0) {
    mscclpp::copy((char*)recvBuff + recvDispls[rank],
                  (void*)((const char*)sendBuff + sendDispls[rank]),
                  sendCounts[rank], gtid, nThreads);
  }

  // Phase 2: Per-peer data transfer via PortChannel (proxy-based).
  // Each block handles one peer: blockIdx.x == peerIdx.
  const int peerIdx = blockIdx.x;
  if (peerIdx >= nPeers) return;

  const int peer = peerIdx < rank ? peerIdx : peerIdx + 1;

  // Thread 0 pushes a put+signal+flush descriptor to the proxy FIFO.
  // The proxy thread performs the actual data transfer (cudaMemcpy or RDMA).
  if (threadIdx.x == 0 && sendCounts[peer] > 0) {
    portChannels[peerIdx].putWithSignalAndFlush(
        remoteRecvDispls[peer],  // dst offset in peer's output buffer
        sendDispls[peer],        // src offset in our input buffer
        sendCounts[peer]         // bytes to transfer
    );
  }
  __syncthreads();

  // Wait for incoming data from this peer
  if (threadIdx.x == 0 && recvCounts[peer] > 0) {
    portChannels[peerIdx].wait();
  }
}

#undef ALLTOALLV_WARP_SIZE
} // namespace collective
} // namespace mscclpp
@@ -12,6 +12,7 @@
#include <mscclpp/port_channel.hpp>
#include <mscclpp/semaphore.hpp>
#include <mscclpp/switch_channel.hpp>
#include <mscclpp/utils.hpp>
#include <unordered_map>
#include <vector>

@@ -42,6 +43,63 @@ std::vector<MemoryChannel> setupMemoryChannels(
    const std::vector<RegisteredMemory>& remoteMemories, RegisteredMemory localMemory, int nChannelsPerConnection);

std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);

/// Setup connections with hybrid transport: CudaIpc for intra-node, IB for inter-node.
/// Dynamically detects if all peers are intra-node (single-node case) and falls back to CudaIpc-only.
/// @param comm Communicator
/// @param localGpuIdx Local GPU index within the node (used to select IB device)
/// @return Vector of connections (one per peer)
std::vector<Connection> setupHybridConnections(std::shared_ptr<Communicator> comm, int localGpuIdx);

/// Check if a connection is intra-node (CudaIpc transport).
/// @param conn The connection to check
/// @return true if the connection uses CudaIpc transport
inline bool isIntraNodeConnection(const Connection& conn) {
  return conn.transport() == Transport::CudaIpc;
}

/// Get the IB transport for a given local GPU index.
/// @param localGpuIdx Local GPU index (0-7)
/// @return The corresponding IB transport
Transport getIBTransportForGpu(int localGpuIdx);

/// Setup PortChannels for inter-node connections via ProxyService.
/// Creates PortChannels only for IB connections, with MemoryId-based addressing.
/// @param proxyService The ProxyService managing IB transfers
/// @param comm The communicator
/// @param connections All connections (mixed CudaIpc + IB)
/// @param remoteMemories Remote registered memories (one per peer)
/// @param localMemory Local registered memory
/// @return Vector of PortChannels (only for IB peers, in connection order)
std::vector<PortChannel> setupPortChannels(
    std::shared_ptr<ProxyService> proxyService,
    Communicator& comm,
    const std::vector<Connection>& connections,
    const std::vector<RegisteredMemory>& remoteMemories,
    RegisteredMemory localMemory);

/// Setup PortChannels for ALL connections (both CudaIpc and IB) via ProxyService.
/// This follows the proven pattern from allgather_test_cpp.cu:
///   - CudaIpc connections: proxy does cudaMemcpyD2D
///   - IB connections: proxy does RDMA write
/// Creates one PortChannel per peer (dense indexing by peerIdx).
/// @param proxyService The ProxyService managing transfers
/// @param comm The communicator
/// @param connections All connections (mixed CudaIpc + IB)
/// @param remoteMemories Remote registered memories (one per peer)
/// @param localMemory Local registered memory
/// @return Vector of PortChannels (one per peer, in connection order)
std::vector<PortChannel> setupAllPortChannels(
    std::shared_ptr<ProxyService> proxyService,
    Communicator& comm,
    const std::vector<Connection>& connections,
    const std::vector<RegisteredMemory>& remoteMemories,
    RegisteredMemory localMemory);

/// Setup PortChannel device handles (GPU-allocated array).
std::shared_ptr<PortChannelDeviceHandle> setupPortChannelDeviceHandles(
    const std::vector<PortChannel>& portChannels);

std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
    std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);