Merge branch 'main' into rjsouza/nvls-allgather-pr

This commit is contained in:
Empyreus
2026-06-23 23:27:24 +00:00
7 changed files with 366 additions and 115 deletions

View File

@@ -236,3 +236,46 @@ class AllToAll(Collective):
}
rank_buffers.append(buffers)
return rank_buffers
class SendRecv(Collective):
"""A SendRecv collective communication pattern.
SendRecv performs a point-to-point send/receive operation.
Each rank sends its input buffer to the next rank and receives data from the
previous rank into its output buffer.
This operation creates input and output buffers both sized by chunk_factor,
as each rank sends and receives the same amount of data.
"""
def __init__(self, num_ranks, chunk_factor, inplace):
"""Initialize a new SendRecv collective.
Args:
num_ranks (int): The number of ranks participating in the SendRecv.
chunk_factor (int): The size factor for data chunks.
inplace (bool): Whether the operation should be performed in-place.
Example:
>>> sendrecv = SendRecv(num_ranks=4, chunk_factor=1, inplace=False)
"""
Collective.__init__(self, num_ranks, chunk_factor, inplace)
self.name = "sendrecv"
def init_buffers(self):
"""Initialize buffers for the SendRecv operation.
Creates input and output buffers both sized by chunk_factor.
Returns:
list: A list of buffer dictionaries, one for each rank.
"""
rank_buffers = []
for rank in range(self.num_ranks):
buffers = {
BufferType.input: BaseBuffer(rank, BufferType.input, 0, self.chunk_factor),
BufferType.output: BaseBuffer(rank, BufferType.output, 0, self.chunk_factor),
}
rank_buffers.append(buffers)
return rank_buffers

View File

@@ -304,11 +304,24 @@ class BaseBuffer:
self.size = offset + size
def __getitem__(self, key):
if self.offset + key.stop > self.size:
raise RuntimeError(
f"Index range from {self.offset + key.start} - {self.offset + key.stop} is out of bounds for buffer {self.buffer_type}. Buffer size: {self.size}"
if not isinstance(key, slice):
raise TypeError(f"Buffer indices must be slices, not {type(key).__name__}")
if key.step is not None and key.step != 1:
raise ValueError(f"Buffer slicing does not support step != 1 (got step={key.step})")
buffer_size = self.size - self.offset
start = key.start if key.start is not None else 0
stop = key.stop if key.stop is not None else buffer_size
if start < 0 or stop < 0:
raise ValueError(
f"Buffer slicing does not support negative indices (got start={key.start}, stop={key.stop})"
)
return Chunk(self.rank, self.buffer_type, self.offset + key.start, key.stop - key.start)
if start > stop:
raise ValueError(f"Buffer slice start ({start}) must be <= stop ({stop})")
if self.offset + stop > self.size:
raise RuntimeError(
f"Index range from {self.offset + start} - {self.offset + stop} is out of bounds for buffer {self.buffer_type}. Buffer size: {self.size}"
)
return Chunk(self.rank, self.buffer_type, self.offset + start, stop - start)
class Buffer(BaseBuffer):

View File

@@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *
def send_recv(name, nnodes, gpus_per_node, split_mask, instances):
gpu_size = nnodes * gpus_per_node
group_size = split_mask + 1
if split_mask < 0 or (split_mask & (split_mask + 1)) != 0 or gpu_size % group_size != 0:
raise ValueError(
f"split_mask must be of the form 2^k - 1 and gpu_size ({gpu_size}) must be divisible by "
f"group_size ({group_size}), got split_mask={hex(split_mask)}"
)
collective = SendRecv(gpu_size, 1, False)
with CollectiveProgram(
name,
collective,
gpu_size,
protocol="Simple",
num_threads_per_block=1024,
use_double_scratch_buffer=False,
min_message_size=0,
max_message_size=2**64 - 1,
instances=instances,
):
# Creating separate port channels for next and prev directions.
# When prev and next are the same peer (e.g., 2-node ring), both channels go to the same peer
# and get distinct tags. To ensure cross-rank tag matching (rank A's prev_channel signal
# arrives at rank B's next_channel wait), we create channels in opposite order for the
# "higher" rank so that tags cross-match:
# Lower rank: [next(tag0), prev(tag1)]
# Higher rank: [prev(tag0), next(tag1)]
# Then lower.prev(tag1) == higher.next(tag1) and higher.prev(tag0) == lower.next(tag0)
# When prev != next (3+ nodes), each channel targets a different peer so each gets tag 0
# and this ordering doesn't matter.
group_size = group_size
num_groups = gpu_size // group_size
next_channels = {} # channel for sending to next rank
prev_channels = {} # channel for receiving from prev rank
prev_next_ids = {}
for node in range(nnodes):
for gpu in range(gpus_per_node):
global_rank_id = gpu + gpus_per_node * node
position_in_group = global_rank_id & split_mask
group_id = global_rank_id // group_size
next_group_id = (group_id + 1) % num_groups
next_global_rank_id = next_group_id * group_size + position_in_group
prev_group_id = (group_id - 1 + num_groups) % num_groups
prev_global_rank_id = prev_group_id * group_size + position_in_group
if prev_global_rank_id == next_global_rank_id and global_rank_id > prev_global_rank_id:
# Higher rank: create prev first, then next (swapped order)
prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
else:
# Lower rank or different peers: create next first, then prev
next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
prev_next_ids[global_rank_id] = (prev_global_rank_id, next_global_rank_id)
# sync with the next rank and the previous rank in the group
for node in range(nnodes):
for gpu in range(gpus_per_node):
global_rank_id = gpu + gpus_per_node * node
prev_global_rank_id, next_global_rank_id = prev_next_ids[global_rank_id]
prev_channels[global_rank_id].signal(tb=0, data_sync=SyncType.none)
next_channels[global_rank_id].wait(tb=0, data_sync=SyncType.after)
src_rank = Rank(global_rank_id)
src_buffer = src_rank.get_input_buffer()
dst_rank = Rank(next_global_rank_id)
dst_buffer = dst_rank.get_output_buffer()
next_channels[global_rank_id].put_with_signal(dst_buffer[:], src_buffer[:], tb=0)
prev_channels[global_rank_id].wait(tb=0, data_sync=SyncType.none)
print(JSON())
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--nnodes", type=int, default=1, help="number of nodes")
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
parser.add_argument("--split_mask", type=lambda x: int(x, 0), default=0x0, help="split mask (e.g. 0x3)")
parser.add_argument("--instances", type=int, default=4, help="number of instances")
args = parser.parse_args()
send_recv(args.name, args.nnodes, args.gpus_per_node, args.split_mask, args.instances)

View File

@@ -14,6 +14,7 @@ from mscclpp import CommGroup, GpuBuffer
from mscclpp.utils import KernelBuilder, pack
import os
import struct
from typing import Callable
import cupy as cp
from mpi4py import MPI
@@ -34,13 +35,13 @@ def parse_dtype(dtype_str):
raise ValueError(f"Unknown data type: {dtype_str}")
def bench_time(n_iters: int, n_graph_iters: int, func):
# capture cuda graph for n_iters of the kernel launch
def bench_time(n_iters: int, n_graph_iters: int, funcs: list[Callable]):
"""Benchmark execution time. `funcs` is a list of callables; iteration i runs funcs[i % len(funcs)]."""
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(n_iters):
func(stream)
funcs[i % len(funcs)](stream)
graph = stream.end_capture()
# now run a warm up round
@@ -61,15 +62,17 @@ def bench_time(n_iters: int, n_graph_iters: int, func):
def bench_correctness(
collective: str,
input_buf: cp.ndarray,
result_buf: cp.ndarray,
test_buf: cp.ndarray,
input_bufs: list[cp.ndarray],
result_bufs: list[cp.ndarray],
test_bufs: list[cp.ndarray],
dtype_str: str,
rank: int,
num_ranks: int,
n_iters: int,
func,
funcs: list[Callable],
split_mask: int = 0,
):
"""Validate correctness. Buffers and funcs are parallel lists; iteration i uses index i % len(funcs)."""
type_size = cp.dtype(parse_dtype(dtype_str)).itemsize
fill_data_kernel_name = "fill_data_%s" % dtype_str
@@ -79,8 +82,12 @@ def bench_correctness(
coll = "reduce_scatter"
elif "allreduce" in collective:
coll = "all_reduce"
else:
elif "alltoall" in collective:
coll = "all_to_all"
elif "sendrecv" in collective:
coll = "send_recv"
else:
raise ValueError(f"Unknown collective: {collective}")
test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str)
file_dir = os.path.dirname(os.path.abspath(__file__))
@@ -97,11 +104,20 @@ def bench_correctness(
with stream:
stream.begin_capture()
for i in range(n_iters):
fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i)
idx = i % len(funcs)
cur_input = input_bufs[idx]
cur_result = result_bufs[idx]
cur_test = test_bufs[idx]
fill_data_params = (
pack(cur_input) + struct.pack("Q", cur_input.nbytes // type_size) + pack(rank, i, split_mask)
)
fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream)
func(stream)
funcs[idx](stream)
test_data_params = (
pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, rank, i)
pack(cur_result, cur_test)
+ struct.pack("Q", cur_input.nbytes // type_size)
+ pack(num_ranks, rank, i, split_mask)
)
test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream)
graph = stream.end_capture()
@@ -143,10 +159,20 @@ def build_bufs(
rank: int,
num_ranks: int,
):
"""Allocate input/result/test buffers. Returns parallel lists (length 2 for sendrecv double-buffering,
length 1 otherwise) so callers can iterate uniformly."""
type_size = cp.dtype(dtype).itemsize
assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size)
nelems = size // type_size
# Sendrecv uses double buffering: build two parallel buffer slots.
if "sendrecv" in collective:
n_slots = 2
input_bufs = [GpuBuffer(nelems, dtype=dtype) for _ in range(n_slots)]
result_bufs = [GpuBuffer(nelems, dtype=dtype) for _ in range(n_slots)]
test_bufs = [cp.zeros(nelems, dtype=dtype) for _ in range(n_slots)]
return input_bufs, result_bufs, test_bufs, nelems
if "allgather" in collective:
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_input = nelems if in_place else nelems // num_ranks
@@ -173,7 +199,7 @@ def build_bufs(
test_buf = cp.zeros(nelems, dtype=dtype)
return input_buf, result_buf, test_buf
return [input_buf], [result_buf], [test_buf], nelems
def main(
@@ -184,8 +210,14 @@ def main(
packet_type: PacketType = PacketType.LL16,
n_iters: int = 10,
n_graph_iters: int = 10,
split_mask: int = 0,
):
mscclpp_group = CommGroup(MPI.COMM_WORLD)
if split_mask < 0 or (split_mask & (split_mask + 1)) != 0 or mscclpp_group.nranks % (split_mask + 1) != 0:
raise ValueError(
f"split_mask must be of the form 2^k - 1 and nranks ({mscclpp_group.nranks}) must be divisible "
f"by group_size ({split_mask + 1}), got split_mask={hex(split_mask)}"
)
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
executor = Executor(mscclpp_group.communicator)
npkit_dump_dir = env().npkit_dump_dir
@@ -195,7 +227,7 @@ def main(
collective = execution_plan.collective
dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(
input_bufs, result_bufs, test_bufs, nelem = build_bufs(
collective,
size,
in_place,
@@ -204,39 +236,48 @@ def main(
mscclpp_group.nranks,
)
executor_func = lambda stream: executor.execute(
mscclpp_group.my_rank,
input_buf.data.ptr,
result_buf.data.ptr,
input_buf.nbytes,
result_buf.nbytes,
dtype_to_mscclpp_dtype(dtype_str),
execution_plan,
stream.ptr,
packet_type,
)
executor_funcs = [
(
lambda stream, inp=inp, res=res: executor.execute(
mscclpp_group.my_rank,
inp.data.ptr,
res.data.ptr,
inp.nbytes,
res.nbytes,
dtype_to_mscclpp_dtype(dtype_str),
execution_plan,
stream.ptr,
packet_type,
)
)
for inp, res in zip(input_bufs, result_bufs)
]
mscclpp_group.barrier()
bench_correctness(
collective,
input_buf,
result_buf,
test_buf,
input_bufs,
result_bufs,
test_bufs,
dtype_str,
mscclpp_group.my_rank,
mscclpp_group.nranks,
n_iters,
executor_func,
executor_funcs,
split_mask=split_mask,
)
mscclpp_group.barrier()
execution_time = bench_time(n_iters, n_graph_iters, executor_func)
execution_time = bench_time(n_iters, n_graph_iters, executor_funcs)
if npkit_dump_dir is not None:
npkit.dump(npkit_dump_dir)
npkit.shutdown()
result_nbytes = result_bufs[0].nbytes
print(
f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, "
f"data size: {result_buf.nbytes} bytes data type: {dtype_str} "
f"data size: {result_nbytes} bytes data type: {dtype_str} "
f"bandwidth: {result_nbytes / (execution_time * 1e-6) / (1024**3):.2f} GB/s, "
f"packet type: {packet_type}"
)
executor = None
@@ -252,6 +293,9 @@ if __name__ == "__main__":
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--n_iters", type=int, default=10)
parser.add_argument("--n_graph_iters", type=int, default=10)
parser.add_argument(
"--split_mask", type=lambda x: int(x, 0), default=0x0, help="split mask for sendrecv (e.g. 0x3)"
)
args = parser.parse_args()
packet_type = PacketType.LL16
@@ -267,4 +311,5 @@ if __name__ == "__main__":
packet_type,
args.n_iters,
args.n_graph_iters,
args.split_mask,
)

View File

@@ -22,14 +22,19 @@ static __device__ unsigned int ranqd1(unsigned int seed) {
// fill/test kernel pairs must have the same thread block size to
// match their random number series.
#define FILL_DATA(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
// `split_mask` groups ranks together: group_size = split_mask + 1, group_id = rank / group_size.
// Data is seeded by group_id so that all ranks within a group produce the same fill, and ranks
// in different groups produce different fills. With split_mask == 0 this reduces to per-rank
// seeding (group_id == rank).
#define FILL_DATA(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq, int split_mask) { \
int seed_rank = rank / (split_mask + 1); \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + seed_rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
}
FILL_DATA(bfloat16, __nv_bfloat16)
@@ -37,18 +42,20 @@ FILL_DATA(float16, __half)
FILL_DATA(float32, float)
FILL_DATA(int32, int)
#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * num_elems; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[rank_offset + i] == test_buf[rank_offset + i]); \
} \
} \
#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
test_data_all_gather_##FuncNameType(DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, \
int my_rank, int seq, int split_mask) { \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * num_elems; \
int seed_rank = rank / (split_mask + 1); \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + seed_rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[rank_offset + i] == test_buf[rank_offset + i]); \
} \
} \
}
TEST_DATA_ALL_GATHER(bfloat16, __nv_bfloat16)
@@ -56,25 +63,27 @@ TEST_DATA_ALL_GATHER(float16, __half)
TEST_DATA_ALL_GATHER(float32, float)
TEST_DATA_ALL_GATHER(int32, int)
#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType, Eps) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
} \
seed = ranqd1(seed); \
test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
float expected = float(test_buf[i]); \
float result = float(result_buf[i]); \
float tol = Eps * num_ranks * (1.0f + abs(expected)); \
assert(abs(result - expected) <= tol); \
} \
#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType, Eps) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
test_data_all_reduce_##FuncNameType(DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, \
int my_rank, int seq, int split_mask) { \
for (int rank = 0; rank < num_ranks; rank++) { \
int seed_rank = rank / (split_mask + 1); \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + seed_rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
} \
seed = ranqd1(seed); \
test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
float expected = float(test_buf[i]); \
float result = float(result_buf[i]); \
float tol = Eps * num_ranks * (1.0f + abs(expected)); \
assert(abs(result - expected) <= tol); \
} \
}
TEST_DATA_ALL_REDUCE(bfloat16, __nv_bfloat16, 7.8125e-3f)
@@ -83,12 +92,14 @@ TEST_DATA_ALL_REDUCE(float32, float, 1.1920929e-7f)
TEST_DATA_ALL_REDUCE(int32, int, 0.0f)
#define TEST_DATA_REDUCE_SCATTER(FuncNameType, DataType, Eps) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_reduce_scatter_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
extern "C" __global__ void __launch_bounds__(1024, 1) \
test_data_reduce_scatter_##FuncNameType(DataType* result_buf, DataType* test_buf, size_t num_elems, \
int num_ranks, int my_rank, int seq, int split_mask) { \
int nem_elems_per_rank = num_elems / num_ranks; \
int offset = nem_elems_per_rank * my_rank; \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
int seed_rank = rank / (split_mask + 1); \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + seed_rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
@@ -112,25 +123,51 @@ TEST_DATA_REDUCE_SCATTER(float16, __half, 9.765625e-4f)
TEST_DATA_REDUCE_SCATTER(float32, float, 1.1920929e-7f)
TEST_DATA_REDUCE_SCATTER(int32, int, 0.0f)
#define TEST_DATA_ALL_TO_ALL(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_to_all_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
int nem_elems_per_rank = num_elems / num_ranks; \
int offset = nem_elems_per_rank * my_rank; \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * nem_elems_per_rank; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
if (i >= my_rank * nem_elems_per_rank && i < (my_rank + 1) * nem_elems_per_rank) { \
test_buf[rank_offset + i - offset] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[rank_offset + i - offset] == test_buf[rank_offset + i - offset]); \
} \
} \
} \
#define TEST_DATA_ALL_TO_ALL(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
test_data_all_to_all_##FuncNameType(DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, \
int my_rank, int seq, int split_mask) { \
int nem_elems_per_rank = num_elems / num_ranks; \
int offset = nem_elems_per_rank * my_rank; \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * nem_elems_per_rank; \
int seed_rank = rank / (split_mask + 1); \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + seed_rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
if (i >= my_rank * nem_elems_per_rank && i < (my_rank + 1) * nem_elems_per_rank) { \
test_buf[rank_offset + i - offset] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[rank_offset + i - offset] == test_buf[rank_offset + i - offset]); \
} \
} \
} \
}
TEST_DATA_ALL_TO_ALL(bfloat16, __nv_bfloat16)
TEST_DATA_ALL_TO_ALL(float16, __half)
TEST_DATA_ALL_TO_ALL(float32, float)
TEST_DATA_ALL_TO_ALL(int32, int)
TEST_DATA_ALL_TO_ALL(int32, int)
// Sendrecv verification: receive from the prev group in the ring.
// fill_data seeds by group_id (rank / (split_mask + 1)); the receiver in group g expects the
// data produced by group (g - 1 + num_groups) % num_groups, so we recompute that seed here.
#define TEST_DATA_SEND_RECV(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
test_data_send_recv_##FuncNameType(DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, \
int my_rank, int seq, int split_mask) { \
int group_size = split_mask + 1; \
int num_groups = num_ranks / group_size; \
int my_group_id = my_rank / group_size; \
int prev_group_id = (my_group_id - 1 + num_groups) % num_groups; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + prev_group_id + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
test_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[i] == test_buf[i]); \
} \
}
TEST_DATA_SEND_RECV(bfloat16, __nv_bfloat16)
TEST_DATA_SEND_RECV(float16, __half)
TEST_DATA_SEND_RECV(float32, float)
TEST_DATA_SEND_RECV(int32, int)

View File

@@ -95,6 +95,7 @@ struct hash<mscclpp::DeviceExecutionPlanKey> {
namespace {
auto hasIBDevices = []() { return mscclpp::getIBDeviceCount() > 0; };
// TODO(binyli): Need to add NVL domain check.
auto useIB = [](int rank1, int rank2, int nranksPerNode) {
if (mscclpp::env()->forceDisableIb) return false;
bool inSameNode = rank1 / nranksPerNode == rank2 / nranksPerNode;
@@ -110,7 +111,7 @@ namespace mscclpp {
struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, Connection> connections;
std::vector<Connection> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
MemoryId localMemoryIdBegin = MemoryId(0);
@@ -122,8 +123,6 @@ struct ExecutionContext {
// local registered memories to keep resources alive
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
std::vector<mscclpp::BasePortChannel> portChannels;
std::vector<mscclpp::SwitchChannel> nvlsChannels;
@@ -267,15 +266,36 @@ struct Executor::Impl {
}
};
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
!useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
connectionFutures.push_back(this->comm->connect(transport, peer));
// Create one connection (unique QP) per channel entry. Each channel gets its own
// QP — no shared connections.
// Use per-peer tag counters so that matched connections between pairs of ranks use
// the same tag, regardless of the order peers appear in each rank's connected_to list.
std::unordered_map<int, int> peerTagCounters;
Transport ibTransport = IBs[rank % this->nranksPerNode];
std::vector<std::shared_future<Connection>> connFutures;
for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
? ibTransport
: Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
}
}
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
? ibTransport
: Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
}
}
}
for (size_t i = 0; i < connectionFutures.size(); i++) {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
for (auto& future : connFutures) {
context.connections.push_back(future.get());
}
std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
@@ -329,10 +349,11 @@ struct Executor::Impl {
std::vector<std::shared_future<Semaphore>> futureProxySemaphores;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
int connIdx = 0;
auto processChannelInfos = [&](std::vector<ChannelInfo>& channelInfos) {
for (ChannelInfo& info : channelInfos) {
for (int peer : info.connectedPeers) {
auto connection = context.connections.at(peer);
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
auto& connection = context.connections[connIdx++];
if (info.channelType == ChannelType::MEMORY) {
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
@@ -361,18 +382,15 @@ struct Executor::Impl {
proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get()));
}
context.memorySemaphores = std::move(memorySemaphores);
context.proxySemaphores = std::move(proxySemaphores);
for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
if (channelType == ChannelType::MEMORY) {
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
context.memoryChannels.emplace_back(memorySemaphores[index++]);
} else if (channelType == ChannelType::PORT) {
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++]));
}
}
}

View File

@@ -174,11 +174,11 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out
uint32_t dstOffset =
dstOffsets[tid] + getOffset<ReuseScratch>(portChannelBufferTypes_[op.outputBufferRefs[tid].id], offset);
uint32_t srcOffset = srcOffsets[tid] + getOffset<ReuseScratch>(op.inputBufferRefs[tid].type, offset);
if constexpr (PutWithSignal) {
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
} else if constexpr (PutWithSignalAndFlush) {
if constexpr (PutWithSignalAndFlush) {
portChannels_[channelIndexes[tid]].putWithSignalAndFlush(dstMemoryId, (uint64_t)dstOffset, srcMemoryId,
(uint64_t)srcOffsets, size);
(uint64_t)srcOffset, size);
} else if constexpr (PutWithSignal) {
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
} else {
portChannels_[channelIndexes[tid]].put(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
}